1
+ import jax
2
+ import jax .numpy as jnp
3
+ from jsrm .systems .gvs import (
4
+ GVS ,
5
+ Joint ,
6
+ Basis
7
+ )
8
+ from jsrm .utils .gvs .custom_types import (
9
+ LinkAttributes ,
10
+ JointAttributes ,
11
+ BasisAttributes ,
12
+ )
13
+
14
+ from diffrax import Tsit5
15
+
16
+ from typing import List
17
+
18
+ import matplotlib .pyplot as plt
19
+
20
+ from functools import partial
21
+
22
+ jax .config .update ("jax_enable_x64" , True ) # double precision
23
+ jnp .set_printoptions (
24
+ threshold = jnp .inf ,
25
+ linewidth = jnp .inf ,
26
+ formatter = {"float_kind" : lambda x : "0" if x == 0 else f"{ x :.2e} " },
27
+ )
28
+
29
+ if __name__ == "__main__" :
30
+ # Define model inputs
31
+ List_links : List [LinkAttributes ] = []
32
+ List_joints : List [JointAttributes ] = []
33
+ List_basis : List [BasisAttributes ] = []
34
+ List_nGauss : List [int ] = []
35
+
36
+ link1 = LinkAttributes (
37
+ section = 'Circular' ,
38
+ E = 1e6 ,
39
+ nu = 0.5 ,
40
+ rho = 1000 ,
41
+ eta = 1e4 ,
42
+ l = 0.3 ,
43
+ r_i = 0.03 ,
44
+ r_f = 0.03
45
+ )
46
+ List_links .append (link1 )
47
+ joint1 = JointAttributes (jointtype = 'Fixed' )
48
+ List_joints .append (joint1 )
49
+ basis1 = BasisAttributes (
50
+ basistype = 'Legendre' ,
51
+ Bdof = [0 , 1 , 1 , 0 , 0 , 0 ],
52
+ Bodr = [0 , 0 , 0 , 0 , 0 , 0 ],
53
+ xi_star = [0 , 0 , 0 , 1 , 0 , 0 ]
54
+ )
55
+ List_basis .append (basis1 )
56
+ List_nGauss .append (5 ) # Number of Gauss points for the first link
57
+
58
+ link2 = LinkAttributes (
59
+ section = 'Circular' ,
60
+ E = 1e6 ,
61
+ nu = 0.5 ,
62
+ rho = 1000 ,
63
+ eta = 1e4 ,
64
+ l = 0.3 ,
65
+ r_i = 0.03 ,
66
+ r_f = 0.03
67
+ )
68
+ List_links .append (link2 )
69
+ joint2 = JointAttributes (jointtype = 'Fixed' )
70
+ # joint2 = JointAttributes(jointtype='Revolute', axis='z')
71
+ List_joints .append (joint2 )
72
+ basis2 = BasisAttributes (
73
+ basistype = 'Monomial' ,
74
+ Bdof = [1 , 1 , 0 , 0 , 0 , 0 ],
75
+ Bodr = [0 , 0 , 0 , 0 , 0 , 0 ],
76
+ xi_star = [0 , 0 , 0 , 1 , 0 , 0 ]
77
+ )
78
+ List_basis .append (basis2 )
79
+ List_nGauss .append (6 ) # Number of Gauss points for the second link
80
+
81
+ link3 = LinkAttributes (
82
+ section = 'Elliptical' , # Section type
83
+ E = 1e7 , # Young's modulus in Pascals
84
+ nu = 0.4 , # Poisson's ratio [-1, 0.5]
85
+ rho = 1050 , # Density [kg/m^3]
86
+ eta = 1e4 , # Damping coefficient
87
+ l = 0.3 , # Length in meters
88
+ a_i = 0.04 , # Initial semi-major axis in meters
89
+ a_f = 0.04 , # Final semi-major axis in meters
90
+ b_i = 0.02 , # Initial semi-minor axis in meters
91
+ b_f = 0.02 # Final semi-minor axis in meters
92
+ )
93
+ List_links .append (link3 )
94
+ joint3 = JointAttributes (
95
+ jointtype = 'Revolute' , # Prismatic joint
96
+ axis = 'z' , # Axis of translation
97
+ )
98
+ List_joints .append (joint3 )
99
+ basis3 = BasisAttributes (
100
+ basistype = 'Chebychev' , # Type of basis
101
+ Bdof = [0 , 1 , 0 , 1 , 0 , 0 ], # Degrees of freedom for each deformation type
102
+ Bodr = [0 , 0 , 0 , 0 , 0 , 0 ], # Order of basis functions for each deformation type
103
+ xi_star = [0 , 0 , 0 , 1 , 0 , 0 ], # Reference strain values as vector
104
+ )
105
+ List_basis .append (basis3 )
106
+ List_nGauss .append (5 ) # Number of Gauss points for the third link
107
+
108
+
109
+ # Create the GVS model
110
+ robot = GVS (
111
+ links_list = List_links ,
112
+ joints_list = List_joints ,
113
+ basis_list = List_basis ,
114
+ n_gauss_list = List_nGauss ,
115
+ gravity_vector = [0 , 0 , 9.81 ]
116
+ )
117
+
118
+ # Initial configuration
119
+ q0 = jnp .ones (robot .dof_tot_system )
120
+ # Initial velocities
121
+ qd0 = jnp .ones (robot .dof_tot_system )
122
+
123
+ # Actuation parameters
124
+ tau = jnp .zeros (robot .dof_tot_system )
125
+ # WARNING: actuation_args need to be a tuple, even if it contains only one element
126
+ # so (tau, ) is necessary NOT (tau) or tau
127
+ actuation_args = (tau ,)
128
+
129
+ # Simulation time parameters
130
+ t0 = 0.0
131
+ t1 = 2.0
132
+ dt = 1e-4
133
+ skip_step = 100 # how many time steps to skip in between video frames
134
+
135
+ # Solver
136
+ solver = Tsit5 () # Runge-Kutta 5(4) method
137
+
138
+ ts , q_ts , q_d_ts = robot .resolve_upon_time (
139
+ q0 = q0 ,
140
+ qd0 = qd0 ,
141
+ actuation_args = actuation_args ,
142
+ t0 = t0 ,
143
+ t1 = t1 ,
144
+ dt = dt ,
145
+ skip_steps = skip_step ,
146
+ max_steps = None ,
147
+ )
148
+
149
+ # =====================================================
150
+ # End-effector position upon time
151
+ # =====================================================
152
+ forward_kinematics = jax .jit (
153
+ partial (
154
+ robot .forward_kinematics ,
155
+ )
156
+ )
157
+ g_ee_ts = jax .vmap (lambda q : forward_kinematics (q )[- 1 ])(q_ts )
158
+
159
+ plt .figure ()
160
+ plt .plot (ts , g_ee_ts [:, 0 , 3 ], label = "End-effector x [m]" )
161
+ plt .plot (ts , g_ee_ts [:, 1 , 3 ], label = "End-effector y [m]" )
162
+ plt .plot (ts , g_ee_ts [:, 2 , 3 ], label = "End-effector z [m]" )
163
+ plt .xlabel ("Time [s]" )
164
+ plt .ylabel ("End-effector position [m]" )
165
+ plt .legend ()
166
+ plt .grid (True )
167
+ plt .box (True )
168
+ plt .tight_layout ()
169
+ plt .show ()
170
+
171
+ fig = plt .figure ()
172
+ ax = fig .add_subplot (111 , projection = "3d" )
173
+ p = ax .scatter (
174
+ g_ee_ts [:, 0 , 3 ], g_ee_ts [:, 1 , 3 ], g_ee_ts [:, 2 , 3 ], c = ts , cmap = "viridis"
175
+ )
176
+ ax .axis ("equal" )
177
+ ax .set_xlabel ("X [m]" )
178
+ ax .set_ylabel ("Y [m]" )
179
+ ax .set_zlabel ("Z [m]" )
180
+ ax .set_title ("End-effector trajectory (3D)" )
181
+ fig .colorbar (p , ax = ax , label = "Time [s]" )
182
+ plt .show ()
183
+
184
+ # # =====================================================
185
+ # # Plot the robot configuration upon time
186
+ # # =====================================================
187
+ # animate_robot_matplotlib(
188
+ # robot,
189
+ # t_list=ts, # shape (T,)
190
+ # q_list=q_ts, # shape (T, DOF)
191
+ # num_points=50,
192
+ # interval=100, # ms
193
+ # slider=True,
194
+ # )
0 commit comments