( nxu/bfloat16.tal - 354 lines, 12 KiB, Tal )

( bfloat16.tal )
( )
( This file implements the bfloat16 format. )
( )
( This differs from IEEE float-16 by providing more exponent bits )
( in exchange for fewer mantissa bits. In other words it trades )
( coarser precision for larger numerical range. )
( )
( The bfloat16 value uses 16-bits divided as follows: )
( - sign (1 bit, 0-1) )
( - exponent (8 bits, 0-255) )
( - mantissa (7 bits, 0-127) )
( )
( Kinds of values: )
( - zeros (exponent==0 mantissa==0) )
( - subnormal (exponent==0 mantissa!=0) )
( - infinities (exponent==255 mantissa==0) )
( - nans (exponent==255 mantissa!=0) )
( - normal (everything else) )
( )
( Equations: )
( - normal = (-1)^sign * 2^(exponent - 127) * (1 + mantissa/128) )
( - subnormal = (-1)^sign * 2^-126 * mantissa/128 )
( (exponent ranges from 1 to 254 since 0 and 255 are special) )
( )
( VALUE SIGN EXPONENT MANTISSA NOTES )
( 0 0 00000000 0000000 )
( -0 1 00000000 0000000 mostly equivalent to zero )
( 1 0 01111111 0000000 )
( 2 0 10000000 0000000 )
( 3 0 10000000 1000000 )
( -1 1 01111111 0000000 )
( 17 0 10000011 0001000 )
( ~9.184e-41 0 00000000 0000001 smallest positive value )
( ~3.390e38 0 11111110 1111111 largest finite value )
( +inf 0 11111111 0000000 positive infinity )
( -inf 1 11111111 0000000 negative infinity )
( nan * 11111111 ******* lots of nans; * is wild )
( )
( Some hex constants: )
( 0 #0000 )
( -0 #8000 )
( 1 #3f80 )
( -1 #bf80 )
( 2 #4000 )
( +inf #7f80 )
( -inf #ff80 )
( nan #ffff (among others) )
( )
( This code doesn't distinguish between quiet and signaling NaNs. )
( )
( Bfloat16 values are emitted in a hexadecimal format: )
( the two fraction digits are the 7 mantissa bits followed by a )
( zero bit, and the exponent after p is written in hex. )
( )
( HEXADECIMAL SIGN EXPONENT MANTISSA DECIMAL )
( 0x1.00p+00 0 01111111 0000000 1.0 )
( 0x0.02p-7e 0 00000000 0000001 ~9.184e-41 )
( -0x1.80p+02 1 10000001 1000000 -6.0 )
( 0x1.c0p+02 0 10000001 1100000 7.0 )
( )
( Eventually I'd like to display integral part of the number )
( in a more natural way but the 1.xx format is OK for now. )
( )
( For consistency zeros are emitted as 0x0.00p+00 and -0x0.00p+00. )
( Infinities are "inf" and "-inf" and NaN is "nan". )
( write one byte to device port #18; presumably Varvara Console/write )
%EMIT { #18 DEO }
( print a single space character )
%SPACE { #20 EMIT }
( print a line feed )
%NEWLINE { #0a EMIT }
( write #ff to device port #0e; presumably Varvara System/debug, which dumps the stacks - verify )
%DEBUG { #ff #0e DEO }
|0100 @on-reset
( #01 ;byte-to-bf16 JSR2 ;test JSR2
#02 ;byte-to-bf16 JSR2 ;test JSR2 )
( #437d -> 0 01010110 1111101 )
( #437c -> 0 01010110 1111100 )
( #00 #86 #7f ;bf16-join JSR2 ;emit-bf16 JSR2 NEWLINE )
( #ff ;byte-to-bf16 JSR2 ;test JSR2 )
( #ff ;byte-to-bf16 JSR2 #01 ;round-shift JSR2 ;test JSR2
#03 ;byte-to-bf16 JSR2 ;test JSR2 )
( print every sample bit pattern next to its decoded value )
;&samples
&each
LDA2k ;test JSR2
INC2 INC2 DUP2 ;&samples-end NEQ2 ,&each JCN
POP2
( 1 + 1 )
#3f80 DUP2 ;add-bf16 JSR2 ;test JSR2
( halt the machine )
#010f DEO BRK
( sample inputs, printed in this order )
&samples
7f80 ff80 ff81 0000 8000 0001 8001
3f80 bf80 4000 4080 4100 3f80
&samples-end
@test ( x* -> )
( print one line: the raw bits in hex, an arrow, then the decoded value )
DUP2 ;emit-u16 JSR2
SPACE #2d EMIT #3e EMIT SPACE
;emit-bf16 JSR2
NEWLINE
JMP2r
@emit-digit ( d^ -> )
( print a digit value 0-15 as its ascii hex character, 0-9 then a-f )
( values above 9 get an extra #27 to skip from after "9" to "a" )
DUP #09 GTH #27 MUL ADD
#30 ADD EMIT
JMP2r
@emit-u8 ( n^ -> )
( print a byte as two hex digits, most significant nibble first )
STHk #04 SFT ;emit-digit JSR2
STHr #0f AND ;emit-digit JMP2
@emit-u16 ( x* -> )
( print a short as four hex digits: high byte, then low byte )
STH ;emit-u8 JSR2
STHr ;emit-u8 JMP2
@emit-s8 ( x^ -> )
( print a two's complement byte as an explicit sign and a hex magnitude )
DUP #80 LTH ,&positive JCN
LIT "- EMIT #00 SWP SUB ;emit-u8 JMP2
&positive LIT "+ EMIT ;emit-u8 JMP2
@emit-s16 ( x* -> )
( print a two's complement short as an explicit sign and a hex magnitude )
( the sign lives in the top bit of the high byte )
OVR #80 LTH ,&positive JCN
LIT "- EMIT #0000 SWP2 SUB2 ;emit-u16 JMP2
&positive LIT "+ EMIT ;emit-u16 JMP2
@emit-bf16 ( x* -> )
( print a bfloat16 as "inf", "-inf", "nan", or a hex float such as )
( 0x1.80p+01 = 3.0; the exponent after p is printed in hex )
;bf16-split JSR2 ( sgn exp mnt )
( sentinel or value )
OVR #ff NEQ ,&non-sentinel JCN
,&is-nan JCN ( sgn exp )
POP #00 EQU ,&pos-inf JCN LIT "- EMIT
&pos-inf LIT "i EMIT LIT "n EMIT LIT "f EMIT JMP2r
( drop sgn and exp so the working stack is balanced on return )
&is-nan POP2 LIT "n EMIT LIT "a EMIT LIT "n EMIT JMP2r
( zero or non-zero )
&non-sentinel DUP2 ORA ,&non-zero JCN
POP2 ,&is-negative-zero JCN ,&zero-suffix JMP
&is-negative-zero LIT "- EMIT
&zero-suffix LIT "0 EMIT LIT "x EMIT LIT "0 EMIT LIT ". EMIT
#00 ;emit-u8 JSR2 LIT "p EMIT #00 ;emit-s8 JSR2 JMP2r
( normal or subnormal )
&non-zero ROT ,&is-negative JCN ,&post-sgn JMP
&is-negative LIT "- EMIT
&post-sgn LIT "0 EMIT LIT "x EMIT ( exp mnt )
OVR ,&is-normal JCN LIT "0 ,&suffix JMP &is-normal LIT "1
&suffix EMIT LIT ". EMIT
( the 7 mantissa bits are the top bits of two hex fraction digits )
#10 SFT ;emit-u8 JSR2 ( exp )
( subnormals are scaled like exponent 1, i.e. 2^-126, not 2^-127 )
DUP #00 EQU ORA
LIT "p EMIT #7f SUB ;emit-s8 JSR2
JMP2r
@bf16-join ( sgn^ exp^ mta^ -> x* )
( pack sign, exponent and mantissa fields into one bfloat16 short )
STH ( sgn exp [mta] )
#00 SWP #70 SFT2 ( sgn exp<<7* [mta] )
ROT #70 SFT #00 ORA2 ( sgn<<15|exp<<7* [mta] )
#00 STHr ORA2 ( x* )
JMP2r
( sgn: 0-1, exp: 0-255, mta: 0-127 )
@bf16-split ( x* -> sgn^ exp^ mta^ )
( shifting the short left by one lines the exponent up with the high byte )
DUP2 #10 SFT2 POP STH ( xhi xlo [exp] )
#7f AND STH ( xhi [exp mta] )
#07 SFT ( sgn [exp mta] )
STH2r ( sgn exp mta )
JMP2r
( sign bit of a bf16 short, as #00 or #01 )
%SIGN { POP #07 SFT }
( biased 8-bit exponent field of a bf16 short )
%EXPONENT { #10 SFT2 POP }
( stored 7-bit mantissa of a bf16 short, without any leading one )
%MANTISSA { NIP #7f AND }
( larger of two unsigned bytes: GTHk JMP skips the SWP when a > b )
%MAX { GTHk JMP SWP POP }
( returns the full mantissa including the leading one: #00 to #ff )
( normal numbers will be >= #80 )
( subnormal numbers will be < #80 )
@full-mantissa ( x* -> fmta^ )
( the implicit leading one is #80 when the exponent is non-zero, else #00 )
DUP2 EXPONENT #00 NEQ #70 SFT STH ( x* [lead] )
MANTISSA STHr ORA
JMP2r
@negate-bf16 ( x* -> z* )
( flip only the sign bit, the top bit of the high byte )
SWP #80 EOR SWP JMP2r
@abs-bf16 ( x* -> z* )
( clear the sign bit in the high byte )
SWP #7f AND SWP JMP2r
@is-zero ( x* -> bool^ )
( shifting left discards the sign bit; +0 and -0 leave nothing behind )
#10 SFT2 ORA #00 EQU JMP2r
@non-zero ( x* -> bool^ )
#10 SFT2 ORA #00 NEQ JMP2r
@is-nan ( x* -> bool^ )
( with the sign shifted out, +/-inf is #ff00 and every nan lies above it )
#10 SFT2 #ff00 GTH2 JMP2r
@non-nan ( x* -> bool^ )
#10 SFT2 #ff01 LTH2 JMP2r
@is-inf ( x* -> bool^ )
#10 SFT2 #ff00 EQU2 JMP2r
( not nan, not +/-inf )
@is-finite ( x* -> bool^ )
#10 SFT2 #ff00 LTH2 JMP2r
( Shift mantissa m right by n bits, with rounding )
( )
( We round differently depending on the value to be lost: )
( )
( 1. If the bits to be removed are > 0.5 we round up )
( 2. If the bits to be removed are < 0.5 we round down )
( 3. If the bits to be removed are = 0.5 we: )
( a. Round up if doing so produces an even mantissa )
( b. Round down if doing so produces an even mantissa )
( )
( This method is useful when adding two values that have )
( different exponents. We will want to truncate the value )
( with the smaller exponent to try to shift the mantissa )
( into the range of the larger value. )
( )
( It's important to remember to include the mantissa's )
( leading one value (if any) before calling this method. )
( n must be in the range 0-8; n = 0 returns mta unchanged. )
@round-shift ( mta^ n^ -> z^ )
( nothing to remove: return the mantissa as-is )
DUP ,&shift JCN
POP JMP2r
&shift
STH2k ( mta n [mta n] )
( r = removed bits = mta & (#ff >> (8-n)) )
#08 SWP SUB #ff SWP SFT AND ( r [mta n] )
( h = exactly one half of the last kept bit = 1 << (n-1) )
STHkr #01 SUB #40 SFT #01 SWP SFT ( r h [mta n] )
DUP2 LTH ,&rnd-down JCN ( r h [mta n] )
GTH ,&rnd-up JCN ( [mta n] )
( exact tie: q + (q & 1) rounds to the even neighbour )
STH2r SFT DUP #01 AND ADD
JMP2r
&rnd-down ( r h [mta n] )
POP2 STH2r SFT JMP2r
&rnd-up ( [mta n] )
STH2r SFT INC JMP2r
( lift an integer byte into a bfloat16 value )
( every byte is exact: 8 significant bits fit in the leading one )
( plus the 7 stored mantissa bits, so no rounding is needed )
@byte-to-bf16 ( n^ -> x* )
( zero has no leading one to find; without this guard the loop never ends )
DUP ,&non-zero JCN
POP #0000 JMP2r
&non-zero
( bit 7 of n weighs 2^7, i.e. biased exponent 127 + 7 = #86 )
#86 SWP ( exp n )
&loop
( shift left until the leading one reaches bit 7 )
DUP #7f GTH ,&ready JCN
#10 SFT SWP #01 SUB SWP
,&loop JMP
&ready
( drop the leading one, then pack exponent and mantissa )
#7f AND STH ( exp [n&7f] )
#00 #01 SFT2 #00 STHr ORA2
JMP2r
( rules: )
( 1. nan = nan is false, as is nan = x )
( 2. x = x is true for any non-nan x )
( 3. (x = y) is (y = x) )
( 4. -0 = +0 is true )
@eq-bf16 ( x* y* -> bool^ )
( a nan on either side makes the values unequal )
OVR2 ;is-nan JSR2 ,&unequal JCN
DUP2 ;is-nan JSR2 ,&unequal JCN
( a zero on the right matches a zero of either sign on the left )
DUP2 ;is-zero JSR2 ,&y-zero JCN
( otherwise equal values have identical bit patterns )
EQU2 JMP2r
&y-zero POP2 ;is-zero JMP2
&unequal POP2 POP2 #00 JMP2r
@ne-bf16 ( x* y* -> bool^ )
( boolean complement of eq-bf16, so nan != anything is true )
;eq-bf16 JSR2 #01 EOR
JMP2r
( rules for sentinels (in order): )
( 1. x < x is false )
( 2a. nan < x is false )
( 2b. x < nan is false )
( 3a. x < +inf is true )
( 3b. +inf < x is false )
( 4a. -inf < x is true )
( 4b. x < -inf is false )
( 5. -0 < +0 is false )
( 6. -x < +x (or 0) )
( 7. 0 < +x )
( 8. x*2^p < y*2^q if p < q )
( 9. x*2^p < y*2^p if x < y )
@lt-bf16 ( x* y* -> bool^ )
,&y STR2 ,&x STR2
( rule 2: the comparison is false when either side is nan )
,&y LDR2 ;non-nan JSR2 ( is y not nan? )
,&x LDR2 ;non-nan JSR2 ( is x not nan? )
AND ,&not-nan JCN #00 JMP2r ( false if x or y is nan )
&not-nan
( rule 5: zeros of either sign are never less than each other )
,&y LDR2 ;non-zero JSR2 ( is y non-zero? )
,&x LDR2 ;non-zero JSR2 ( is x non-zero? )
ORA ,&not-zero JCN #00 JMP2r ( false if x and y are both zero )
&not-zero
( rule 6: with differing signs, x < y exactly when x is negative )
,&x LDR2 SIGN ( sign of x )
,&y LDR2 SIGN ( sign of y )
EQUk ,&same-sign JCN GTH JMP2r ( return unless signs are eq )
[ &x $2 &y $2 ]
&same-sign
( same sign: the bit patterns order like sign-magnitude integers )
POP ,&is-negative JCN
,&x LDR2 ,&y LDR2 LTH2 JMP2r ( for positives, use integer x < y )
&is-negative
,&x LDR2 ,&y LDR2 GTH2 JMP2r ( for negatives, use integer x > y )
( x > y is evaluated as y < x, so every rule of lt-bf16 applies )
@gt-bf16 ( x* y* -> bool^ )
SWP2 ;lt-bf16 JSR2
JMP2r
( special cases: )
( 1. x + y = y + x )
( 2. nan + x = nan )
( 3. 0 + x = x )
( 4. inf + (-inf) = nan )
( 5. inf + x = inf )
( 6. -inf + x = -inf )
( Finite operands are added exactly using 7 extra guard bits plus )
( a sticky bit, then rounded to nearest with ties to even. Results )
( too large for the format become +/-inf. x + (-x) gives +0. )
@add-bf16 ( x* y* -> z* )
( nan on either side )
OVR2 ;is-nan JSR2 ,&nan JCN
DUP2 ;is-nan JSR2 ,&nan JCN
( an infinity dominates any finite value )
DUP2 ;is-inf JSR2 ,&y-inf JCN
OVR2 ;is-inf JSR2 ,&x-inf JCN
,&finite JMP
&nan POP2 POP2 #ffff JMP2r
( inf + -inf is nan, otherwise the result is y's infinity )
&y-inf SWP2 #8000 EOR2 EQU2k ,&nan JCN POP2 JMP2r
&x-inf POP2 JMP2r
&finite
( order the operands so that |a| >= |b|; then ea >= eb and a's sign wins )
OVR2 #7fff AND2 OVR2 #7fff AND2 LTH2 #00 EQU ,&ordered JCN
SWP2
&ordered
;&b STA2 ;&a STA2
( A* = full mantissa of a with the leading one at bit 14 )
;&a LDA2 ;full-mantissa JSR2 #00 SWP #70 SFT2 ( A* )
( effective exponents: subnormals scale like exponent 1 )
;&a LDA2 EXPONENT DUP #00 EQU ORA DUP ;&e STA ( A* ea^ )
;&b LDA2 EXPONENT DUP #00 EQU ORA SUB ( A* d^ )
;&b LDA2 ;full-mantissa JSR2 #00 SWP #70 SFT2 ( A* d^ B* )
ROT ;bf16-sticky-shift JSR2 ( A* B'* )
( add magnitudes when the signs agree, subtract otherwise )
;&a LDA ;&b LDA EOR #80 AND ,&subtract JCN
ADD2 ,&check-zero JMP
&subtract SUB2
&check-zero ( R* )
ORAk ,&nonzero JCN
( exact cancellation or two zeros: only -0 + -0 keeps the sign )
POP2 ;&a LDA2 ;&b LDA2 AND2 #8000 AND2 JMP2r
&nonzero
( a carry into bit 15 moves the result up one binade )
DUP2 #8000 LTH2 ,&normalize JCN
#01 ;bf16-sticky-shift JSR2
;&e LDA INC ;&e STA
&normalize
( shift left until the leading one reaches bit 14, stopping at subnormals )
DUP2 #4000 LTH2 #00 EQU ,&round JCN
;&e LDA #01 EQU ,&round JCN
#10 SFT2
;&e LDA #01 SUB ;&e STA
,&normalize JMP
&round ( R* )
( q = R >> 7, rem = the 7 guard bits; round half to even )
DUP #7f AND STH ( R* [rem] )
#07 SFT2 ( q* [rem] )
DUP #01 AND STHkr #40 EQU AND ( q* tie-and-odd^ [rem] )
STHr #40 GTH ORA ( q* up^ )
#00 SWP ADD2 ( q* )
( (e-1)<<7 + q encodes normals and subnormals alike, and a rounding )
( carry in q simply bumps the exponent field )
;&e LDA #01 SUB #00 SWP #70 SFT2 ADD2 ( z* )
DUP2 #7f80 LTH2 ,&finite-result JCN
POP2 #7f80
&finite-result
;&a LDA #80 AND #00 ORA2
JMP2r
[ &a $2 &b $2 &e $1 ]
( Shift x right by n bits, OR-ing a 1 into bit 0 if any set bit was )
( shifted out, so rounding can tell exactly-half from just-over-half )
@bf16-sticky-shift ( x* n^ -> y* )
DUP #10 LTH ,&in-range JCN
( every bit is shifted out; only the sticky bit can remain )
POP #0000 NEQ2 #00 SWP JMP2r
&in-range
STH ( x* [n] )
DUP2 STHkr SFT2 ( x* r* [n] )
DUP2 STHkr #40 SFT SFT2 ( x* r* r<<n* [n] )
ROT2 NEQ2 ( r* lost^ [n] )
#00 SWP ORA2 ( y* [n] )
POPr JMP2r
( TODO )
( lots of stuff including: )
( - subtraction )
( - multiplication )
( - division )
( - floor/ceil/round )
( - min/max )
( - log2/exp2 )