# springDamperUserFunctionNumbaJIT.py
# You can view and download this file on GitHub: springDamperUserFunctionNumbaJIT.py
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# This is an EXUDYN example
#
# Details:  Test with user-defined load function and user-defined spring-damper function (Duffing oscillator)
#
# Author:   Johannes Gerstmayr
# Date:     2019-11-15
#
# Copyright: This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
#
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
12
import sys
# NOTE(review): the flag is set BEFORE the first exudyn import below;
# presumably it selects Exudyn's faster module variant -- confirm in the Exudyn docs.
sys.exudynFast = True

from exudyn.utilities import ClearWorkspace
ClearWorkspace()

import exudyn as exu
from exudyn.utilities import * #includes itemInterface and rigidBodyUtilities
import exudyn.graphics as graphics #only import if it does not conflict

import numpy as np
24
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
#NUMBA PART; mainly, we need to register MainSystem mbs in numba to get user functions to work
#import numba jit for compilation of functions:
# from numba import jit

#create identity operator for replacement of jit:
try:
    from numba import jit
    print('running WITH JIT')
except Exception: #numba missing or unusable; do not swallow KeyboardInterrupt/SystemExit
    print('running WITHOUT JIT')

    def jit(*args, **kwargs):
        """Stand-in for numba.jit that leaves the decorated function unchanged.

        Works both as a bare decorator (``@jit``) and as a decorator factory
        (``@jit(nopython=True)`` or ``@jit('float64(float64)')``); any options
        are accepted and ignored.
        """
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0] #bare @jit: hand back the function itself
        return lambda ob: ob #@jit(...): return an identity decorator

# from numba import jit, cfunc, types, njit
# from numba.types import float64, void, int64 #for signatures of user functions!
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
42
43
44# @jit
45# def myfunc():
46# print("my function")
47
48#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
49
50
useGraphics = False #without test


#create the system container and the multibody system used by the whole script
SC = exu.SystemContainer()
mbs = SC.AddSystem()
exu.Print('EXUDYN version='+exu.GetVersionString())
57
L=0.5               #x-coordinate used as reference position of the mass point node
mass = 1.6          #mass in kg
spring = 4000       #stiffness of spring-damper in N/m
damper = 4          #damping constant in N/(m/s)
load0 = 80          #load amplitude passed to the load user function

omega0=np.sqrt(spring/mass)     #angular frequency sqrt(k/m) of the linear spring-mass system
f0 = 0.*omega0/(2*np.pi)        #sweep start frequency (zero)
f1 = 1.*omega0/(2*np.pi)        #sweep end frequency: omega0 converted by 1/(2*pi)
67
exu.Print('resonance frequency = '+str(omega0)) #NOTE: omega0 is sqrt(spring/mass), i.e. not yet divided by 2*pi
tEnd = 50     #end time of simulation
steps = 1000000  #number of steps
71
72#first test without JIT:
73
def sf(u,v,k,d):
    """Return the nonlinear spring-damper force of the Duffing-type test.

    The result is an odd polynomial in ``u`` (terms in u, u**3, u**5 and
    u**7, each scaled by ``k``) plus the linear term ``v*d``.

    u, v, k, d are forwarded unchanged from the spring-damper user function
    (see springForce).
    """
    return 0.1*k*u+k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7+v*d
76
def springForce(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring force user function (pure Python version, no JIT).

    Only u, v, k and d are used; mbs2, t, itemIndex and offset are accepted
    to match the user-function signature and are ignored.
    NOTE(review): u/v are presumably the relative displacement/velocity and
    k/d the stiffness/damping of the CoordinateSpringDamper -- confirm
    against the Exudyn springForceUserFunction documentation.
    """
    return sf(u,v,k,d)
    # x=test(mbs.systemData.GetTime()) #5 microseconds
    # q=mbs.systemData.GetODE2Coordinates() #5 microseconds
    # return 0.1*k*u+k*u**3+v*d
82
#linear frequency sweep in time interval [0, t1] and frequency interval [f0,f1];
def Sweep(t, t1, f0, f1):
    """Return the value at time t of a unit-amplitude linear frequency sweep.

    t:  current time
    t1: end of the sweep interval [0, t1]; must be nonzero (used as divisor)
    f0: frequency at t=0
    f1: frequency at t=t1
    """
    k = (f1-f0)/t1 #rate of change of the frequency over time
    return np.sin(2*np.pi*(f0+k*0.5*t)*t) #take care of factor 0.5 in k*0.5*t, in order to obtain correct frequencies!!!
87
#user function for load; void replaces mbs, which then may not be used!!!
#most time lost due to pybind11 std::function capturing; no simple way to overcome problem at this point (avoid many function calls!)
#@cfunc(float64(void, float64, float64)) #possible, but does not lead to speed up
#@jit #not possible because of mbs not recognized by numba
def userLoad(mbs, t, load):
    """Load user function (pure Python version): scale load by the sweep at time t.

    mbs is accepted to match the user-function signature but is not used.
    The sweep parameters tEnd, f0 and f1 are read from module-level globals.
    """
    #x=mbs.systemData.GetTime() #call to systemData function takes around 5us ! Cannot be optimized!
    #global tEnd, f0, f1 #global does not change performance
    return load*Sweep(t, tEnd, f0, f1) #global variable does not seem to make problems!
96
#node for 3D mass point:
n1=mbs.AddNode(Point(referenceCoordinates = [L,0,0]))

#ground node
nGround=mbs.AddNode(NodePointGround(referenceCoordinates = [0,0,0]))

#add mass point (this is a 3D object with 3 coordinates):
massPoint = mbs.AddObject(MassPoint(physicsMass = mass, nodeNumber = n1))

#marker for ground (=fixed):
groundMarker=mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= nGround, coordinate = 0))
#marker for springDamper for first (x-)coordinate:
nodeMarker  =mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= n1, coordinate = 0))

#Spring-Damper between two marker coordinates
#(first run uses the pure Python user function springForce)
oSD=mbs.AddObject(CoordinateSpringDamper(markerNumbers = [groundMarker, nodeMarker],
                                         stiffness = spring, damping = damper,
                                         springForceUserFunction = springForce,
                                         ))

#add load:
#(first run uses the pure Python user function userLoad)
loadC = mbs.AddLoad(LoadCoordinate(markerNumber = nodeMarker,
                                   load = load0,
                                   loadUserFunction=userLoad,
                                   ))

mbs.Assemble()
122
simulationSettings = exu.SimulationSettings()
simulationSettings.solutionSettings.writeSolutionToFile = False
simulationSettings.timeIntegration.numberOfSteps = steps
simulationSettings.timeIntegration.endTime = tEnd
simulationSettings.timeIntegration.newton.useModifiedNewton=True

simulationSettings.timeIntegration.generalizedAlpha.spectralRadius = 1

#show solver statistics and CPU time, needed for the timing comparison of both runs
simulationSettings.displayStatistics = True
simulationSettings.displayComputationTime = True
simulationSettings.timeIntegration.verboseMode = 1

#start solver:
mbs.SolveDynamic(simulationSettings)

#evaluate final (=current) output values
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('displacement=',u[0])
143
144
145#%%+++++++++++++++++++++++++++++++++++++++++++++++++++++
146#run again with JIT included:
147
#use jit for every time-consuming part
#the more complex it gets, the larger the speedup will be!
#however, this part can only contain simple structures (no mbs, no exudyn functions [but you could @jit them!])
@jit
def sf2(u,v,k,d):
    """JIT-decorated copy of sf: odd polynomial in u scaled by k, plus v*d."""
    return 0.1*k*u+k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7+v*d
154
def springForce2(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring force user function that delegates to the JIT-decorated sf2.

    Only u, v, k and d are used; mbs2, t, itemIndex and offset are accepted
    to match the user-function signature and are ignored. The wrapper itself
    is not decorated with @jit.
    """
    return sf2(u,v,k,d)
157
# jit for both sub-functions of user functions:
# replace the spring force user function of the existing spring-damper by the JIT-based variant
mbs.SetObjectParameter(oSD, 'springForceUserFunction', springForce2)
160
#jit gives us speedup and works out of the box:
@jit
def Sweep2(t, t1, f0, f1):
    """JIT-decorated copy of Sweep: unit-amplitude linear frequency sweep at time t.

    The frequency runs from f0 at t=0 to f1 at t=t1; t1 must be nonzero
    (used as divisor).
    """
    k = (f1-f0)/t1 #rate of change of the frequency over time
    return np.sin(2*np.pi*(f0+k*0.5*t)*t) #take care of factor 0.5 in k*0.5*t, in order to obtain correct frequencies!!!
166
#user function for load; void replaces mbs, which then may not be used!!!
# @cfunc(float64(void, float64, float64), nopython=True, fastmath=True) #possible, but does not lead to speed up
def userLoad2(mbs, t, load):
    """Load user function that delegates to the JIT-decorated Sweep2.

    mbs is accepted to match the user-function signature but is not used.
    The sweep parameters tEnd, f0 and f1 are read from module-level globals.
    """
    return load*Sweep2(t, tEnd, f0, f1) #global variable does not seem to make problems!
171
#replace the load user function by the JIT-based variant and solve again with the same settings
mbs.SetLoadParameter(loadC,'loadUserFunction', userLoad2)

mbs.SolveDynamic(simulationSettings)

#evaluate final (=current) output values
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('JIT, displacement=',u[0])
179
180
#performance:
#1e6 time steps
# no user functions:
# tCPU=1.15 seconds

# regular, Python user function for spring-damper and load:
# tCPU=16.7 seconds

# jit, Python user function for spring-damper and load:
# tCPU=5.58 seconds (on average)
#==>speedup of user function part: (16.7-1.15)/(5.58-1.15)=3.51
#speedup will be much larger if Python functions are larger!
#approx. 400,000 Python function calls/second!