jax conversion of trace.py in ray-optics #103
Hello,
Try changing
to
If that doesn't work, I'm not clear what the problem is.
Replacing the computational engine that ray-optics uses might best be approached from the bottom up, in terms of the program structure. For ray-optics, the core ray trace calculations are done in the raytrace module and the profiles module. The trace_raw function is the core ray trace loop. There are functions for computing the optical interaction (bend, reflect, phase) in the

The sequential path definition is the key data structure that the ray trace loop runs over. The sequential path is a list of tuples, each tuple consisting of (using Python type hints):
where:
The path and reverse_path methods of the sequential model return

An example of a "minimal" ray trace test is test_sequential.py in the rayoptics/raytr/tests directory. It uses a function gen_sequence to provide the sequential path definition for the trace_raw call. This example touches the smallest footprint of the ray-optics code while traversing most paths through the core ray trace code.

Once the low-level ray trace is working, the next area to include might be the analyses module, for tracing grids and lists of rays.

I think what you're trying to do is interesting and ambitious. For a code base like ray-optics, it is important to understand how the modules fit together and how they are layered, from computation outward to the user interface. If I can help you by documenting how these different things fit together (and which things you don't have to worry about), I'm more than happy to answer questions and fill in details. But I can't do line-by-line debugging; I hope you understand.
I am trying to convert trace.py to JAX. To do that, I am importing the equivalent JAX libraries instead of the NumPy ones. The original trace.py uses newton and fsolve from scipy.optimize, but jax.scipy doesn't provide newton or fsolve, so I am trying to use minimize for the same operation. I have attached the code below; that's the only function I have changed.
In the sub-function y_stop_coordinate, the line "pt1 = np.array([0., y1, dist])" throws a ValueError: "All input arrays must have the same shape". y1 is of type jax.interpreters.ad.JVPTracer and dist is of type float.
The call "out = stack([asarray(elt,dtype=dtype)])" in lax_numpy.py (line 1900) is what raises it.
How can I overcome this error? (I want the files converted to JAX so I can use automatic differentiation.)
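One possible cause, offered as a guess rather than a diagnosis: jax.scipy.optimize.minimize expects a 1-D x0, so the objective may receive y1 with shape (1,) while dist is a 0-d float, and stacking the two would then fail on the shape mismatch. The sketch below squeezes y1 to a scalar before building pt1, and solves for the root with a hand-written Newton iteration driven by jax.grad instead of minimize. The names residual and newton_scalar and the toy linear "trace" are hypothetical stand-ins, not ray-optics code.

```python
import jax
import jax.numpy as jnp


def residual(y1, dist, y_target):
    """Hypothetical stand-in for y_stop_coordinate: miss distance at the stop."""
    # Force y1 to a 0-d value so it can be stacked with plain floats,
    # whether it arrives as a scalar or as a length-1 array.
    y1 = jnp.squeeze(y1)
    pt1 = jnp.array([0.0, y1, dist])
    # Toy linear "trace" (assumption); replace with the real ray trace.
    y_at_stop = 0.5 * pt1[1] + 0.1 * pt1[2]
    return y_at_stop - y_target


def newton_scalar(f, x0, args=(), tol=1e-6, max_iter=50):
    """Plain Newton iteration for a scalar root, using jax.grad for f'."""
    df = jax.grad(f)  # derivative with respect to the first argument
    x = jnp.asarray(x0, dtype=jnp.float32)
    for _ in range(max_iter):
        fx = f(x, *args)
        if abs(float(fx)) < tol:
            break
        x = x - fx / df(x, *args)
    return x


root = newton_scalar(residual, 0.0, args=(10.0, 2.0))
# The same residual also accepts a length-1 array, as minimize would pass it.
check = residual(jnp.array([float(root)]), 10.0, 2.0)
```

The Python loop keeps the sketch easy to step through in a debugger; it is not jit-compatible as written, and differentiating through the solve itself would need further work.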