testGymCartpoleEnv.py

You can view and download this file on GitHub: testGymCartpoleEnv.py

  1#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  2# This is an EXUDYN example
  3#
  4# Details:  This file serves as an input to testGymCartpole.py
  5#
  6# Author:   Johannes Gerstmayr, Grzegorz Orzechowski
  7# Date:     2022-05-17
  8# Update:   2023-05-20: derive from gym.Env to ensure compatibility with newer stable-baselines3
  9#
 10# Copyright:This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
 11#
 12#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 13
 14import exudyn as exu
 15from exudyn.utilities import * #includes itemInterface and rigidBodyUtilities
 16import exudyn.graphics as graphics #only import if it does not conflict
 17import math
 18
 19import math
 20from typing import Optional, Union
 21
 22import numpy as np
 23
 24import gym
 25from gym import logger, spaces, Env
 26from gym.error import DependencyNotInstalled
 27
 28import stable_baselines3
 29useOldGym = tuple(map(int, stable_baselines3.__version__.split('.'))) <= tuple(map(int, '1.8.0'.split('.')))
 30
 31
 32class CartPoleEnv(Env):
 33
 34    #metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}
 35    metadata = {"render_modes": ["human"], "render_fps": 50}
 36
 37    def __init__(self, thresholdFactor = 1., forceFactor = 1.):
 38
 39        self.SC = exu.SystemContainer()
 40        self.mbs = self.SC.AddSystem()
 41
 42
 43        #%%+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 44        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 45        #take variables from cartpole example in openAIgym
 46        self.gravity = 9.8
 47        self.masscart = 1.0
 48        self.masspole = 0.1
 49        self.total_mass = self.masspole + self.masscart
 50        self.lengthHalf = 0.5  # actually half the pole's self.length
 51        self.length = self.lengthHalf*2
 52        self.polemass_length = self.masspole * self.length
 53        self.force_mag = 10.0*forceFactor
 54        self.stepUpdateTime = 0.02  # seconds between state updates
 55        self.kinematics_integrator = "euler"
 56
 57        # Angle at which to fail the episode
 58        self.theta_threshold_radians = thresholdFactor*12 * 2 * math.pi / 360
 59        self.x_threshold = thresholdFactor*2.4
 60
 61        high = np.array(
 62            [
 63                self.x_threshold * 2,
 64                np.finfo(np.float32).max,
 65                self.theta_threshold_radians * 2,
 66                np.finfo(np.float32).max,
 67            ],
 68            dtype=np.float32,
 69        )
 70
 71        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 72        #see https://github.com/openai/gym/blob/64b4b31d8245f6972b3d37270faf69b74908a67d/gym/core.py#L16
 73        #for Env:
 74        self.action_space = spaces.Discrete(2)
 75        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
 76        #+++++++++++++++++++++++++++++++++++++++++++++++++++++
 77        self.state = None
 78        self.rendererRunning=None
 79        self.useRenderer = False #turn this on if needed
 80
 81        background = graphics.CheckerBoard(point= [0,0,0], normal= [0,0,1], size=4)
 82
 83        oGround=self.mbs.AddObject(ObjectGround(referencePosition= [0,0,0],  #x-pos,y-pos,angle
 84                                           visualization=VObjectGround(graphicsData= [background])))
 85        nGround=self.mbs.AddNode(NodePointGround())
 86
 87        gCart = graphics.Brick(size=[0.5*self.length, 0.1*self.length, 0.1*self.length],
 88                                           color=graphics.color.dodgerblue)
 89        self.nCart = self.mbs.AddNode(Rigid2D(referenceCoordinates=[0,0,0]));
 90        oCart = self.mbs.AddObject(RigidBody2D(physicsMass=self.masscart,
 91                                          physicsInertia=0.1*self.masscart, #not needed
 92                                          nodeNumber=self.nCart,
 93                                          visualization=VObjectRigidBody2D(graphicsData= [gCart])))
 94
 95        gPole = graphics.Brick(size=[0.1*self.length, self.length, 0.1*self.length], color=graphics.color.red)
 96        self.nPole = self.mbs.AddNode(Rigid2D(referenceCoordinates=[0,0.5*self.length,0]));
 97        oPole = self.mbs.AddObject(RigidBody2D(physicsMass=self.masspole,
 98                                          physicsInertia=1e-6, #not included in original paper
 99                                          nodeNumber=self.nPole,
100                                          visualization=VObjectRigidBody2D(graphicsData= [gPole])))
101
102        mCartCOM = self.mbs.AddMarker(MarkerNodePosition(nodeNumber=self.nCart))
103        mPoleCOM = self.mbs.AddMarker(MarkerNodePosition(nodeNumber=self.nPole))
104        mPoleJoint = self.mbs.AddMarker(MarkerBodyPosition(bodyNumber=oPole, localPosition=[0,-0.5*self.length,0]))
105
106        mCartCoordX = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=self.nCart, coordinate=0))
107        mCartCoordY = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=self.nCart, coordinate=1))
108        mGroundNode = self.mbs.AddMarker(MarkerNodeCoordinate(nodeNumber=nGround, coordinate=0))
109
110        #gravity
111        self.mbs.AddLoad(Force(markerNumber=mCartCOM, loadVector=[0,-self.masscart*self.gravity,0]))
112        self.mbs.AddLoad(Force(markerNumber=mPoleCOM, loadVector=[0,-self.masspole*self.gravity,0]))
113
114        #control force
115        self.lControl = self.mbs.AddLoad(LoadCoordinate(markerNumber=mCartCoordX, load=1.))
116
117        #constraints:
118        self.mbs.AddObject(RevoluteJoint2D(markerNumbers=[mCartCOM, mPoleJoint]))
119        self.mbs.AddObject(CoordinateConstraint(markerNumbers=[mCartCoordY, mGroundNode]))
120
121
122
123
124        #%%++++++++++++++++++++++++
125        self.mbs.Assemble() #computes initial vector
126
127        self.simulationSettings = exu.SimulationSettings() #takes currently set values or default values
128
129
130        self.simulationSettings.timeIntegration.numberOfSteps = 1
131        self.simulationSettings.timeIntegration.endTime = 0 #will be overwritten in step
132        self.simulationSettings.timeIntegration.verboseMode = 0
133        self.simulationSettings.solutionSettings.writeSolutionToFile = False
134        #self.simulationSettings.timeIntegration.simulateInRealtime = True
135
136        self.simulationSettings.timeIntegration.newton.useModifiedNewton = True
137
138        self.SC.visualizationSettings.general.drawWorldBasis=True
139        self.SC.visualizationSettings.general.graphicsUpdateInterval = 0.01 #50Hz
140
141        self.simulationSettings.solutionSettings.solutionInformation = "Open AI gym"
142
143        self.dynamicSolver = exudyn.MainSolverImplicitSecondOrder()
144        self.dynamicSolver.InitializeSolver(self.mbs, self.simulationSettings)
145        self.dynamicSolver.SolveSteps(self.mbs, self.simulationSettings) #to initialize all data
146
147
148    def integrateStep(self, force):
149        #exudyn simulation part
150        #index 2 solver
151        self.mbs.SetLoadParameter(self.lControl, 'load', force)
152
153        #progress integration time
154        currentTime = self.simulationSettings.timeIntegration.endTime
155        self.simulationSettings.timeIntegration.startTime = currentTime
156        self.simulationSettings.timeIntegration.endTime = currentTime+self.stepUpdateTime
157
158        self.dynamicSolver.InitializeSolverInitialConditions(self.mbs, self.simulationSettings)
159        self.dynamicSolver.SolveSteps(self.mbs, self.simulationSettings)
160        currentState = self.mbs.systemData.GetSystemState() #get current values
161        self.mbs.systemData.SetSystemState(systemStateList=currentState,
162                                        configuration = exu.ConfigurationType.Initial)
163        self.mbs.systemData.SetODE2Coordinates_tt(coordinates = self.mbs.systemData.GetODE2Coordinates_tt(),
164                                                configuration = exu.ConfigurationType.Initial)
165
166
167
168
169    def step(self, action):
170        err_msg = f"{action!r} ({type(action)}) invalid"
171        assert self.action_space.contains(action), err_msg
172        assert self.state is not None, "Call reset before using step method."
173        x, x_dot, theta, theta_dot = self.state
174
175        force = self.force_mag if action == 1 else -self.force_mag
176
177        #++++++++++++++++++++++++++++++++++++++++++++++++++
178        #++++++++++++++++++++++++++++++++++++++++++++++++++
179        self.integrateStep(force)
180        #+++++++++++++++++++++++++
181        #compute some output:
182        cartPosX = self.mbs.GetNodeOutput(self.nCart, variableType=exu.OutputVariableType.Coordinates)[0]
183        poleAngle = self.mbs.GetNodeOutput(self.nPole, variableType=exu.OutputVariableType.Coordinates)[2]
184        cartPosX_t = self.mbs.GetNodeOutput(self.nCart, variableType=exu.OutputVariableType.Coordinates_t)[0]
185        poleAngle_t = self.mbs.GetNodeOutput(self.nPole, variableType=exu.OutputVariableType.Coordinates_t)[2]
186
187        #finally write updated state:
188        self.state = (cartPosX, cartPosX_t, poleAngle, poleAngle_t)
189        #++++++++++++++++++++++++++++++++++++++++++++++++++
190        #++++++++++++++++++++++++++++++++++++++++++++++++++
191
192        done = bool(
193            cartPosX < -self.x_threshold
194            or cartPosX > self.x_threshold
195            or poleAngle < -self.theta_threshold_radians
196            or poleAngle > self.theta_threshold_radians
197        )
198
199        if not done:
200            reward = 1.0
201        elif self.steps_beyond_done is None:
202            # Pole just fell!
203            self.steps_beyond_done = 0
204            reward = 1.0
205        else:
206            if self.steps_beyond_done == 0:
207                logger.warn(
208                    "You are calling 'step()' even though this "
209                    "environment has already returned done = True. You "
210                    "should always call 'reset()' once you receive 'done = "
211                    "True' -- any further steps are undefined behavior."
212                )
213            self.steps_beyond_done += 1
214            reward = 0.0
215
216        info = {}
217        terminated, truncated = done, False # since stable-baselines3 > 1.8.0 implementations terminated and truncated
218        if useOldGym:
219            return np.array(self.state, dtype=np.float32), reward, terminated, info
220        else:
221            return np.array(self.state, dtype=np.float32), reward, terminated, truncated, info
222
223
224    def reset(
225        self,
226        *,
227        seed: Optional[int] = None,
228        return_info: bool = False,
229        options: Optional[dict] = None,
230    ):
231        #super().reset(seed=seed)
232        #self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
233        self.state = np.random.uniform(low=-0.05, high=0.05, size=(4,))
234        self.steps_beyond_done = None
235
236
237        #+++++++++++++++++++++++++++++++++++++++++++++
238        #set initial values:
239        #+++++++++++++++++++++++++++++++++++++++++++++
240        #set specific initial state:
241        (xCart, xCart_t, phiPole, phiPole_t) = self.state
242
243        initialValues = np.zeros(6)
244        initialValues_t = np.zeros(6)
245        initialValues[0] = xCart
246        initialValues[3+0] = xCart - 0.5*self.length * sin(phiPole)
247        initialValues[3+1] = 0.5*self.length * (cos(phiPole)-1)
248        initialValues[3+2] = phiPole
249
250        initialValues_t[0] = xCart_t
251        initialValues_t[3+0] = xCart_t - phiPole_t*0.5*self.length * cos(phiPole)
252        initialValues_t[3+1] = -0.5*self.length * sin(phiPole)  * phiPole_t
253        initialValues_t[3+2] = phiPole_t
254
255        self.mbs.systemData.SetODE2Coordinates(initialValues, exu.ConfigurationType.Initial)
256        self.mbs.systemData.SetODE2Coordinates_t(initialValues_t, exu.ConfigurationType.Initial)
257
258        self.simulationSettings.timeIntegration.endTime = 0
259        #self.dynamicSolver.FinalizeSolver(self.mbs, self.simulationSettings) #needed to update initial conditions
260        self.dynamicSolver.InitializeSolver(self.mbs, self.simulationSettings) #needed to update initial conditions
261        # self.dynamicSolver.InitializeSolverInitialConditions(self.mbs, self.simulationSettings) #needed to update initial conditions
262
263        if not return_info and useOldGym:
264            return np.array(self.state, dtype=np.float32)
265        else:
266            return np.array(self.state, dtype=np.float32), {}
267
268    def render(self, mode="human"):
269        if self.rendererRunning==None and self.useRenderer:
270            SC.renderer.Start()
271            self.rendererRunning = True
272
273    def close(self):
274        self.dynamicSolver.FinalizeSolver(self.mbs, self.simulationSettings)
275        if self.rendererRunning==True:
276            # SC.renderer.DoIdleTasks()
277            SC.renderer.Stop() #safely close rendering window!
278
279
280
281# #+++++++++++++++++++++++++++++++++++++++++++++
282# #reset:
283# self.mbs.systemData.SetODE2Coordinates(initialValues, exu.ConfigurationType.Initial)
284# self.mbs.systemData.SetODE2Coordinates_t(initialValues, exu.ConfigurationType.Initial)
285# self.mbs.systemData.SetODE2Coordinates_tt(initialValues, exu.ConfigurationType.Initial)