18 | 18 | import java.nio.FloatBuffer; |
19 | 19 | import java.nio.IntBuffer; |
20 | 20 | import java.nio.LongBuffer; |
| 21 | +import java.nio.ShortBuffer; |
21 | 22 | import java.util.Arrays; |
22 | 23 | import java.util.Locale; |
23 | 24 | import org.pytorch.executorch.annotations.Experimental; |
@@ -57,6 +58,7 @@ public abstract class Tensor { |
57 | 58 | private static final int BYTE_SIZE_BYTES = 1; |
58 | 59 | private static final int INT_SIZE_BYTES = 4; |
59 | 60 | private static final int LONG_SIZE_BYTES = 8; |
| 61 | + private static final int HALF_SIZE_BYTES = 2; |
60 | 62 | private static final int FLOAT_SIZE_BYTES = 4; |
61 | 63 | private static final int DOUBLE_SIZE_BYTES = 8; |
62 | 64 |
@@ -107,6 +109,18 @@ public static LongBuffer allocateLongBuffer(int numElements) { |
107 | 109 | .asLongBuffer(); |
108 | 110 | } |
109 | 111 |
| 112 | + /** |
| 113 | + * Allocates a new direct {@link ShortBuffer} with native byte order and specified capacity that |
| 114 | + * can be used in {@link Tensor#fromBlob(ShortBuffer, long[])}. |
| 115 | + * |
| 116 | + * @param numElements capacity (number of elements) of result buffer. |
| 117 | + */ |
| 118 | + public static ShortBuffer allocateHalfBuffer(int numElements) { |
| 119 | + return ByteBuffer.allocateDirect(numElements * HALF_SIZE_BYTES) |
| 120 | + .order(ByteOrder.nativeOrder()) |
| 121 | + .asShortBuffer(); |
| 122 | + } |
| 123 | + |
110 | 124 | /** |
111 | 125 | * Allocates a new direct {@link DoubleBuffer} with native byte order with specified capacity that |
112 | 126 | * can be used in {@link Tensor#fromBlob(DoubleBuffer, long[])}. |
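
As the Javadoc notes, the returned buffer pairs with the `fromBlob(ShortBuffer, long[])` overload added in this patch. A minimal usage sketch, with illustrative IEEE-754 half-precision bit patterns (1.0, -2.0, 0.0, +Inf):

```java
// Sketch: allocate a direct, native-order buffer for four float16 elements
// and wrap it as a 2x2 tensor. The literals are example half-precision bit patterns.
ShortBuffer half = Tensor.allocateHalfBuffer(4);
half.put(new short[] {0x3C00, (short) 0xC000, 0x0000, 0x7C00}); // 1.0, -2.0, 0.0, +Inf
Tensor t = Tensor.fromBlob(half, new long[] {2, 2});
// t.dtype() == DType.HALF
```
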
@@ -187,6 +201,23 @@ public static Tensor fromBlob(float[] data, long[] shape) { |
187 | 201 | return new Tensor_float32(floatBuffer, shape); |
188 | 202 | } |
189 | 203 |
| 204 | + /** |
| 205 | + * Creates a new Tensor instance with dtype torch.float16 with specified shape and data as array |
| 206 | + * of IEEE-754 half-precision values encoded in {@code short}s. |
| 207 | + * |
| 208 | + * @param data Tensor elements encoded as 16-bit floats. |
| 209 | + * @param shape Tensor shape |
| 210 | + */ |
| 211 | + public static Tensor fromBlob(short[] data, long[] shape) { |
| 212 | + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); |
| 213 | + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); |
| 214 | + checkShape(shape); |
| 215 | + checkShapeAndDataCapacityConsistency(data.length, shape); |
| 216 | + final ShortBuffer shortBuffer = allocateHalfBuffer((int) numel(shape)); |
| 217 | + shortBuffer.put(data); |
| 218 | + return new Tensor_float16(shortBuffer, shape); |
| 219 | + } |
| 220 | + |
190 | 221 | /** |
191 | 222 | * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of |
192 | 223 | * longs. |
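
Callers that start from `float` values need to encode each element into half-precision bits before passing a `short[]` to this overload. Below is a hypothetical caller-side encoder, not part of this patch, using simplified round-half-up rounding; where available, `Float.floatToFloat16` (JDK 20+) or Android's `android.util.Half.toHalf` can be used instead:

```java
final class HalfEncode {
  // Hypothetical helper: encode a float as IEEE-754 binary16 bits (round half up).
  static short floatToHalfBits(float value) {
    int bits = Float.floatToIntBits(value);
    int sign = (bits >>> 16) & 0x8000; // sign bit moved to the half position
    int exp = (bits >>> 23) & 0xFF;    // float exponent field
    int mant = bits & 0x7FFFFF;        // float mantissa

    if (exp == 0xFF) { // Inf or NaN
      return (short) (sign | 0x7C00 | (mant != 0 ? 0x0200 : 0));
    }
    int halfExp = exp - 127 + 15; // rebias the exponent
    if (halfExp >= 0x1F) {        // too large: overflow to +/-Inf
      return (short) (sign | 0x7C00);
    }
    if (halfExp <= 0) {           // subnormal half or zero
      if (halfExp < -10) {
        return (short) sign;      // too small: flush to signed zero
      }
      mant |= 0x800000;           // restore the implicit leading 1
      int shift = 14 - halfExp;
      int halfMant = mant >> shift;
      if (((mant >> (shift - 1)) & 1) != 0) { // round half up
        halfMant++;               // a carry into 0x400 yields the smallest normal
      }
      return (short) (sign | halfMant);
    }
    int halfMant = mant >> 13;    // keep the top 10 mantissa bits
    if ((mant & 0x1000) != 0) {   // round half up
      halfMant++;
      if (halfMant == 0x400) {    // mantissa carry: bump the exponent
        halfMant = 0;
        if (++halfExp >= 0x1F) {
          return (short) (sign | 0x7C00);
        }
      }
    }
    return (short) (sign | (halfExp << 10) | halfMant);
  }
}
```

For example, `floatToHalfBits(1.0f)` yields `0x3C00` and `floatToHalfBits(-2.0f)` yields `(short) 0xC000`.
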
@@ -301,6 +332,26 @@ public static Tensor fromBlob(FloatBuffer data, long[] shape) { |
301 | 332 | return new Tensor_float32(data, shape); |
302 | 333 | } |
303 | 334 |
| 335 | + /** |
| 336 | + * Creates a new Tensor instance with dtype torch.float16 with specified shape and data. |
| 337 | + * |
| 338 | + * @param data Direct buffer with native byte order that contains {@code Tensor.numel(shape)} |
| 339 | + * elements encoded as IEEE-754 half-precision floats. The buffer is used directly without |
| 340 | + * copying. |
| 341 | + * @param shape Tensor shape |
| 342 | + */ |
| 343 | + public static Tensor fromBlob(ShortBuffer data, long[] shape) { |
| 344 | + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); |
| 345 | + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); |
| 346 | + checkShape(shape); |
| 347 | + checkShapeAndDataCapacityConsistency(data.capacity(), shape); |
| 348 | + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); |
| 349 | + checkArgument( |
| 350 | + (data.order() == ByteOrder.nativeOrder()), |
| 351 | + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); |
| 352 | + return new Tensor_float16(data, shape); |
| 353 | + } |
| 354 | + |
304 | 355 | /** |
305 | 356 | * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. |
306 | 357 | * |
@@ -388,6 +439,16 @@ public byte[] getDataAsByteArray() { |
388 | 439 | "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); |
389 | 440 | } |
390 | 441 |
| 442 | + /** |
| 443 | + * @return a Java short array that contains the raw IEEE-754 half-precision bit patterns of the |
| 444 | + * tensor data. This may be a copy or reference. |
| 445 | + * @throws IllegalStateException if it is called for a non-float16 tensor. |
| 446 | + */ |
| 447 | + public short[] getDataAsShortArray() { |
| 448 | + throw new IllegalStateException( |
| 449 | + "Tensor of type " + getClass().getSimpleName() + " cannot return data as short array."); |
| 450 | + } |
| 451 | + |
391 | 452 | /** |
392 | 453 | * @return a Java byte array that contains the tensor data. This may be a copy or reference. |
393 | 454 | * @throws IllegalStateException if it is called for a non-uint8 tensor. |
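
For a float16 tensor the two accessors differ only in representation: one returns the stored bit patterns, the other widens each element to `float`. A short sketch, using illustrative bit patterns for 1.0 and -2.0 and the float16 implementation added further down in this patch:

```java
Tensor t = Tensor.fromBlob(new short[] {0x3C00, (short) 0xC000}, new long[] {2});
short[] bits = t.getDataAsShortArray(); // the stored bit patterns (0x3C00, 0xC000)
float[] vals = t.getDataAsFloatArray(); // the widened values (1.0f, -2.0f)
```
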
@@ -569,6 +630,74 @@ public String toString() { |
569 | 630 | } |
570 | 631 | } |
571 | 632 |
| 633 | + static class Tensor_float16 extends Tensor { |
| 634 | + private final ShortBuffer data; |
| 635 | + |
| 636 | + private Tensor_float16(ShortBuffer data, long[] shape) { |
| 637 | + super(shape); |
| 638 | + this.data = data; |
| 639 | + } |
| 640 | + |
| 641 | + @Override |
| 642 | + public DType dtype() { |
| 643 | + return DType.HALF; |
| 644 | + } |
| 645 | + |
| 646 | + @Override |
| 647 | + Buffer getRawDataBuffer() { |
| 648 | + return data; |
| 649 | + } |
| 650 | + |
| 651 | + @Override |
| 652 | + public short[] getDataAsShortArray() { |
| 653 | + data.rewind(); |
| 654 | + short[] arr = new short[data.remaining()]; |
| 655 | + data.get(arr); |
| 656 | + return arr; |
| 657 | + } |
| 658 | + |
| 659 | + @Override |
| 660 | + public float[] getDataAsFloatArray() { |
| 661 | + data.rewind(); |
| 662 | + int remaining = data.remaining(); |
| 663 | + float[] arr = new float[remaining]; |
| 664 | + for (int i = 0; i < remaining; i++) { |
| 665 | + arr[i] = halfBitsToFloat(data.get()); |
| 666 | + } |
| 667 | + return arr; |
| 668 | + } |
| 669 | + |
| 670 | + @Override |
| 671 | + public String toString() { |
| 672 | + return String.format("Tensor(%s, dtype=torch.float16)", Arrays.toString(shape)); |
| 673 | + } |
| 674 | + |
| 675 | + private static float halfBitsToFloat(short halfBits) { |
| 676 | + int h = halfBits & 0xFFFF; |
| 677 | + int sign = (h >>> 15) & 0x1; |
| 678 | + int exp = (h >>> 10) & 0x1F; |
| 679 | + int mant = h & 0x3FF; |
| 680 | + |
| 681 | + if (exp == 0) { |
| 682 | + if (mant == 0) { |
| 683 | + return sign == 0 ? 0.0f : -0.0f; |
| 684 | + } |
| 685 | + float result = mant * 5.9604645e-8f; // 2^-24 |
| 686 | + return sign == 0 ? result : -result; |
| 687 | + } else if (exp == 0x1F) { |
| 688 | + if (mant == 0) { |
| 689 | + return sign == 0 ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY; |
| 690 | + } |
| 691 | + int bits = (sign << 31) | 0x7f800000 | (mant << 13); |
| 692 | + return Float.intBitsToFloat(bits); |
| 693 | + } else { |
| 694 | + int exp32 = exp + 112; // 127 (float bias) - 15 (half bias) |
| 695 | + int bits = (sign << 31) | (exp32 << 23) | (mant << 13); |
| 696 | + return Float.intBitsToFloat(bits); |
| 697 | + } |
| 698 | + } |
| 699 | + } |
| 700 | + |
572 | 701 | static class Tensor_int64 extends Tensor { |
573 | 702 | private final LongBuffer data; |
574 | 703 |
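
The expected results of the widening above follow from the three branches of `halfBitsToFloat` (subnormal/zero, Inf/NaN, normal) and can be exercised through the public API; the bit patterns below are illustrative:

```java
// Edge cases of the half -> float widening (sketch).
Tensor t = Tensor.fromBlob(
    new short[] {0x0001, (short) 0x8000, (short) 0xFC00, 0x7E00}, new long[] {4});
float[] decoded = t.getDataAsFloatArray();
// decoded[0] == 5.9604645e-8f   (smallest subnormal, 1 * 2^-24)
// decoded[1] == -0.0f           (signed zero)
// decoded[2] == Float.NEGATIVE_INFINITY
// Float.isNaN(decoded[3])       (NaN payload shifted into float mantissa bits 13..22)
```
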
@@ -691,6 +820,8 @@ private static Tensor nativeNewTensor( |
691 | 820 |
692 | 821 | if (DType.FLOAT.jniCode == dtype) { |
693 | 822 | tensor = new Tensor_float32(data.asFloatBuffer(), shape); |
| 823 | + } else if (DType.HALF.jniCode == dtype) { |
| 824 | + tensor = new Tensor_float16(data.asShortBuffer(), shape); |
694 | 825 | } else if (DType.INT32.jniCode == dtype) { |
695 | 826 | tensor = new Tensor_int32(data.asIntBuffer(), shape); |
696 | 827 | } else if (DType.INT64.jniCode == dtype) { |
@@ -727,6 +858,11 @@ public byte[] toByteArray() { |
727 | 858 | tensorAsByteArray = new byte[(int) numel()]; |
728 | 859 | Tensor_int8 thiz = (Tensor_int8) this; |
729 | 860 | ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray()); |
| 861 | + } else if (dtype() == DType.HALF) { |
| 862 | + dtypeSize = HALF_SIZE_BYTES; |
| 863 | + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; |
| 864 | + Tensor_float16 thiz = (Tensor_float16) this; |
| 865 | + ByteBuffer.wrap(tensorAsByteArray).asShortBuffer().put(thiz.getDataAsShortArray()); |
730 | 866 | } else if (dtype() == DType.INT16) { |
731 | 867 | throw new IllegalArgumentException("DType.INT16 is not supported in Java so far"); |
732 | 868 | } else if (dtype() == DType.INT32) { |
@@ -794,6 +930,8 @@ public static Tensor fromByteArray(byte[] bytes) { |
794 | 930 | return new Tensor_uint8(buffer, shape); |
795 | 931 | } else if (dtype == DType.INT8.jniCode) { |
796 | 932 | return new Tensor_int8(buffer, shape); |
| 933 | + } else if (dtype == DType.HALF.jniCode) { |
| 934 | + return new Tensor_float16(buffer.asShortBuffer(), shape); |
797 | 935 | } else if (dtype == DType.INT32.jniCode) { |
798 | 936 | return new Tensor_int32(buffer.asIntBuffer(), shape); |
799 | 937 | } else if (dtype == DType.INT64.jniCode) { |
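
A round-trip sketch through the serialization path, assuming the byte layout written by `toByteArray` is the same one `fromByteArray` parses:

```java
// Serialize a float16 tensor and restore it; the bit patterns encode 1.0 and -2.0.
Tensor original = Tensor.fromBlob(new short[] {0x3C00, (short) 0xC000}, new long[] {2});
byte[] serialized = original.toByteArray();
Tensor restored = Tensor.fromByteArray(serialized);
// Expected: restored.dtype() == DType.HALF and
// Arrays.equals(restored.getDataAsShortArray(), original.getDataAsShortArray())
```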