springDamperUserFunctionNumbaJIT.py

You can view and download this file on GitHub: springDamperUserFunctionNumbaJIT.py

  1#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  2# This is an EXUDYN example
  3#
  4# Details:  Test with user-defined load function and user-defined spring-damper function (Duffing oscillator)
  5#
  6# Author:   Johannes Gerstmayr
  7# Date:     2019-11-15
  8#
  9# Copyright:This file is part of Exudyn. Exudyn is free software. You can redistribute it and/or modify it under the terms of the Exudyn license. See 'LICENSE.txt' for more details.
 10#
 11#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 12
 13import sys
 14sys.exudynFast = True
 15
 16from exudyn.utilities import ClearWorkspace
 17ClearWorkspace()
 18
 19import exudyn as exu
 20from exudyn.utilities import * #includes itemInterface and rigidBodyUtilities
 21import exudyn.graphics as graphics #only import if it does not conflict
 22
 23import numpy as np
 24
 25#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
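#NUMBA PART; numba does not know the MainSystem type: user functions receiving mbs could only be jit-compiled if MainSystem mbs were registered in numba, therefore only their mbs-free helper functions are jit-compiled below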
 27#import numba jit for compilation of functions:
 28# from numba import jit
 29
 30#create identity operator for replacement of jit:
 31try:
 32    from numba import jit
 33    print('running WITH JIT')
 34except: #define replacement operator
 35    print('running WITHOUT JIT')
 36    def jit(ob):
 37        return ob
 38
 39# from numba import jit, cfunc, types, njit
 40# from numba.types import float64, void, int64 #for signatures of user functions!
 41#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 42
 43
 44# @jit
 45# def myfunc():
 46#     print("my function")
 47
 48#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 49
 50
 51useGraphics = False #without test
 52
 53
 54SC = exu.SystemContainer()
 55mbs = SC.AddSystem()
 56exu.Print('EXUDYN version='+exu.GetVersionString())
 57
 58L=0.5
 59mass = 1.6          #mass in kg
 60spring = 4000       #stiffness of spring-damper in N/m
 61damper = 4          #damping constant in N/(m/s)
 62load0 = 80
 63
 64omega0=np.sqrt(spring/mass)
 65f0 = 0.*omega0/(2*np.pi)
 66f1 = 1.*omega0/(2*np.pi)
 67
 68exu.Print('resonance frequency = '+str(omega0))
 69tEnd = 50     #end time of simulation
 70steps = 1000000  #number of steps
 71
 72#first test without JIT:
 73
 74def sf(u,v,k,d):
 75    return 0.1*k*u+k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7+v*d
 76
 77def springForce(mbs2, t, itemIndex, u, v, k, d, offset):
 78    return sf(u,v,k,d)
 79    # x=test(mbs.systemData.GetTime()) #5 microseconds
 80    # q=mbs.systemData.GetODE2Coordinates() #5 microseconds
 81    # return 0.1*k*u+k*u**3+v*d
 82
 83#linear frequency sweep in time interval [0, t1] and frequency interval [f0,f1];
 84def Sweep(t, t1, f0, f1):
 85    k = (f1-f0)/t1
 86    return np.sin(2*np.pi*(f0+k*0.5*t)*t) #take care of factor 0.5 in k*0.5*t, in order to obtain correct frequencies!!!
 87
 88#user function for load; void replaces mbs, which then may not be used!!!
 89#most time lost due to pybind11 std::function capturing; no simple way to overcome problem at this point (avoid many function calls!)
 90#@cfunc(float64(void, float64, float64)) #possible, but does not lead to speed up
 91#@jit #not possible because of mbs not recognized by numba
 92def userLoad(mbs, t, load):
 93    #x=mbs.systemData.GetTime() #call to systemData function takes around 5us ! Cannot be optimized!
 94    #global tEnd, f0, f1 #global does not change performance
 95    return load*Sweep(t, tEnd, f0, f1) #global variable does not seem to make problems!
 96
 97#node for 3D mass point:
 98n1=mbs.AddNode(Point(referenceCoordinates = [L,0,0]))
 99
100#ground node
101nGround=mbs.AddNode(NodePointGround(referenceCoordinates = [0,0,0]))
102
103#add mass point (this is a 3D object with 3 coordinates):
104massPoint = mbs.AddObject(MassPoint(physicsMass = mass, nodeNumber = n1))
105
106#marker for ground (=fixed):
107groundMarker=mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= nGround, coordinate = 0))
108#marker for springDamper for first (x-)coordinate:
109nodeMarker  =mbs.AddMarker(MarkerNodeCoordinate(nodeNumber= n1, coordinate = 0))
110
111#Spring-Damper between two marker coordinates
112oSD=mbs.AddObject(CoordinateSpringDamper(markerNumbers = [groundMarker, nodeMarker],
113                                     stiffness = spring, damping = damper,
114                                     springForceUserFunction = springForce,
115                                     ))
116
117#add load:
118loadC = mbs.AddLoad(LoadCoordinate(markerNumber = nodeMarker,
119                           load = load0,
120                           loadUserFunction=userLoad,
121                           ))
122
123mbs.Assemble()
124
#solver settings (also reused for the second, JIT-based run):
simulationSettings = exu.SimulationSettings()

#output: no solution file, but statistics and computation time are displayed
simulationSettings.solutionSettings.writeSolutionToFile = False
simulationSettings.displayStatistics = True
simulationSettings.displayComputationTime = True

#time integration: 'steps' steps up to tEnd with modified Newton;
#generalized-alpha spectral radius 1 means no numerical damping
simulationSettings.timeIntegration.numberOfSteps = steps
simulationSettings.timeIntegration.endTime = tEnd
simulationSettings.timeIntegration.verboseMode = 1
simulationSettings.timeIntegration.newton.useModifiedNewton = True
simulationSettings.timeIntegration.generalizedAlpha.spectralRadius = 1

#first run: plain Python user functions springForce and userLoad
mbs.SolveDynamic(simulationSettings)

#evaluate final (=current) output: x-component of the mass point position;
#NOTE(review): labeled 'displacement', but Position of node n1 includes its reference
#coordinate L - OutputVariableType.Displacement would give the pure displacement
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('displacement=', u[0])
143
144
145#%%+++++++++++++++++++++++++++++++++++++++++++++++++++++
146#run again with JIT included:
147
#use jit for every time-consuming part; the more complex the function, the larger the speedup.
#However, jit-compiled functions can only contain simple structures
#(no mbs, no exudyn functions [but you could @jit them!])
@jit
def sf2(u,v,k,d):
    """Same nonlinear spring-damper force law as sf, compiled with jit if numba is available.

    Returns 0.1*k*u + k*u**3 + 1e-3*k*u**5 + 1e-6*k*u**7 + v*d, summed in this order.
    """
    linearTerm = 0.1*k*u
    cubicTerm = k*u**3
    quinticTerm = 1e-3*k*u**5
    septicTerm = 1e-6*k*u**7
    dampingTerm = v*d
    return linearTerm + cubicTerm + quinticTerm + septicTerm + dampingTerm
154
def springForce2(mbs2, t, itemIndex, u, v, k, d, offset):
    """Spring-force user function for the JIT run.

    The wrapper itself stays a plain Python function (it receives mbs2, which
    numba does not recognize); only the force law sf2 is jit-compiled.
    mbs2, t, itemIndex and offset are not used.
    """
    force = sf2(u, v, k, d)
    return force

#exchange the user function of the existing spring-damper oSD by the JIT-based version:
mbs.SetObjectParameter(oSD, 'springForceUserFunction', springForce2)
160
#jit gives us speedup and works out of the box for such purely numerical functions:
@jit
def Sweep2(t, t1, f0, f1):
    """Same linear frequency sweep as Sweep, compiled with jit if numba is available.

    Returns sin(2*pi*(f0 + 0.5*k*t)*t) with k = (f1-f0)/t1; the factor 0.5 makes the
    instantaneous frequency f0 + k*t reach f1 at t = t1.
    """
    sweepRate = (f1-f0)/t1
    halfRate = sweepRate*0.5
    meanFrequency = f0 + halfRate*t   #average frequency over [0, t]
    return np.sin(2*np.pi*meanFrequency*t)
166
#load user function for the JIT run; it receives mbs and therefore stays plain Python.
#Decorating it with @cfunc(float64(void, float64, float64), nopython=True, fastmath=True),
#where void replaces mbs (which then must not be used), is possible but does not lead to speed up.
def userLoad2(mbs, t, load):
    """Return load times the jit-compiled sweep Sweep2(t, tEnd, f0, f1); mbs is not used."""
    excitation = Sweep2(t, tEnd, f0, f1)
    return load*excitation
171
#exchange the load user function by the JIT-based version and solve again with the same settings:
mbs.SetLoadParameter(loadC, 'loadUserFunction', userLoad2)
mbs.SolveDynamic(simulationSettings)

#evaluate final (=current) output: x-component of the mass point position;
#NOTE(review): labeled 'displacement', but Position of node n1 includes its reference
#coordinate L - OutputVariableType.Displacement would give the pure displacement
u = mbs.GetNodeOutput(n1, exu.OutputVariableType.Position)
exu.Print('JIT, displacement=', u[0])
179
180
181#performance:
182#1e6 time steps
183# no user functions:
184# tCPU=1.15 seconds
185
186# regular, Python user function for spring-damper and load:
187# tCPU=16.7 seconds
188
189# jit, Python user function for spring-damper and load:
190# tCPU=5.58 seconds (on average)
#==>speedup of user function part: (16.7-1.15)/(5.58-1.15) = 15.55/4.43 = approx. 3.5
#speedup will be much larger if Python functions are larger!
#approx. 400,000 Python function calls/second!