; Our zero-page vars
; Multi-byte values are stored low byte first: the add/sub macros in this
; file treat offset +0 as the least significant byte.
sx    = $80     ; 8 bits: screen pixel x
sy    = $81     ; 8 bits: screen pixel y
cx    = $82     ; 16 bits fixed point (4.12): x position input to mandelbrot
cy    = $84     ; 16 bits fixed point (4.12): y position input to mandelbrot
zx    = $86     ; 16 bits fixed point (4.12)
zy    = $88     ; 16 bits fixed point (4.12)
zx_2  = $8a     ; 32 bits fixed point: zx * zx as produced by imul16 (8.24)
zy_2  = $8e     ; 32 bits fixed point: zy * zy as produced by imul16 (8.24)
zx_zy = $92     ; 32 bits fixed point: zx * zy as produced by imul16 (8.24)
dist  = $96     ; 32 bits fixed point: zx_2 + zy_2, used for the escape test
iter  = $9a     ; 8 bits iteration count

temp  = $a0     ; debug temp area (start saves the final iter here)

; FP registers in zero page
; NOTE(review): these look like the Atari OS floating-point work areas being
; reused as scratch space - confirm against the target platform's memory map.
FR0 = $d4       ; imul16_func: first 16-bit operand
FRE = $da       ; not referenced in the code visible here
FR1 = $e0       ; imul16_func: second 16-bit operand
FR2 = $e6       ; imul16_func: 32-bit result

.code

.export start

; add bytes, dest, arg1, arg2
; Multi-byte little-endian addition: dest = arg1 + arg2 over `bytes` bytes.
; Clobbers A and the flags; the carry out of the top byte is left in C.
; 2 + 9 * bytes cycles when every operand is in zero page.
.macro add bytes, dest, arg1, arg2
    clc ; 2 cyc - no carry into the lowest byte
    .repeat bytes, byte ; 9 cyc per byte
        lda arg1 + byte ; 3 cyc
        adc arg2 + byte ; 3 cyc - carry ripples up from the previous byte
        sta dest + byte ; 3 cyc
    .endrepeat
.endmacro

; 16-bit add: dest = arg1 + arg2 (20 cycles in zero page)
.macro add16 dest, arg1, arg2
    add 2, dest, arg1, arg2
.endmacro

; 32-bit add: dest = arg1 + arg2 (38 cycles in zero page)
; All four bytes are summed and both source operands are honoured.
.macro add32 dest, arg1, arg2
    add 4, dest, arg1, arg2
.endmacro

; sub bytes, dest, arg1, arg2
; Multi-byte little-endian subtraction: dest = arg1 - arg2 over `bytes` bytes.
; Clobbers A and the flags; C is clear afterwards if a borrow left the top byte.
; 2 + 9 * bytes cycles when every operand is in zero page.
.macro sub bytes, dest, arg1, arg2
    sec                      ; 2 cyc - start with no borrow
    .repeat bytes, offs      ; 9 cyc per byte
        lda arg1 + offs      ; 3 cyc
        sbc arg2 + offs      ; 3 cyc - borrow ripples up from the previous byte
        sta dest + offs      ; 3 cyc
    .endrepeat
.endmacro

; 16-bit subtract: dest = arg1 - arg2
.macro sub16 dest, arg1, arg2
    sub 2, dest, arg1, arg2
.endmacro

; 32-bit subtract: dest = arg1 - arg2
.macro sub32 dest, arg1, arg2
    sub 4, dest, arg1, arg2
.endmacro

; shl bytes, arg
; Shift the little-endian value at arg left by one bit across `bytes` bytes.
; The low byte is shifted with ASL, then each higher byte in turn is rotated
; with ROL so the bit that fell out of the byte below is carried into it.
; Clobbers the flags; the bit shifted out of the top byte is left in C.
; 5 * bytes cycles when arg is in zero page.
.macro shl bytes, arg
    asl arg              ; 5 cyc - bit 7 of the low byte goes to C
    .repeat bytes-1, i
        rol arg + 1 + i  ; 5 cyc - C into bit 0, bit 7 out to C
    .endrepeat
.endmacro

; 16-bit shift left by one
.macro shl16 arg
    shl 2, arg
.endmacro

; 24-bit shift left by one
.macro shl24 arg
    shl 3, arg
.endmacro

; 32-bit shift left by one
.macro shl32 arg
    shl 4, arg
.endmacro

; copy bytes, dest, arg
; Copy `bytes` bytes from arg to dest, lowest address first.
; Clobbers A and the N/Z flags.
; 6 * bytes cycles when both operands are in zero page.
.macro copy bytes, dest, arg
    .repeat bytes, offs
        lda arg + offs   ; 3 cyc
        sta dest + offs  ; 3 cyc
    .endrepeat
.endmacro

; Copy a 16-bit value (12 cycles in zero page)
.macro copy16 dest, arg
    copy 2, dest, arg
.endmacro

; Copy a 32-bit value (24 cycles in zero page)
.macro copy32 dest, arg
    copy 4, dest, arg
.endmacro

; neg bytes, arg
; Two's-complement negate the little-endian value at arg in place by
; computing 0 - arg one byte at a time.
; Clobbers A and the flags.
; 2 + 8 * bytes cycles when arg is in zero page.
.macro neg bytes, arg
    sec                    ; 2 cyc - start with no borrow
    .repeat bytes, offs    ; 8 cyc per byte
        lda #0             ; 2 cyc
        sbc arg + offs     ; 3 cyc - borrow ripples up from the previous byte
        sta arg + offs     ; 3 cyc
    .endrepeat
.endmacro

; Negate a 16-bit value in place: 18 cycles
.macro neg16 arg
    neg 2, arg
.endmacro

; Negate a 32-bit value in place: 34 cycles
.macro neg32 arg
    neg 4, arg
.endmacro

; bitmul16 arg1, arg2, result, bitnum
; One step of the unrolled shift-and-add multiply in imul16_func.
;
; Tests bit `bitnum` of the 16-bit arg1. If the bit is set, the 16-bit arg2
; is added into the top half (bytes 2..3) of the 32-bit result. In both cases
; the result is then shifted right by one bit, so that after steps 0..15 have
; run in order the full product has moved down into place.
;
; Clobbers A and the flags. arg1 and arg2 are only read.
;
; Cycle counts, assuming zero-page operands and no page-crossing branch
; (a taken branch costs 3 cycles rather than the 2 annotated below):
;   bit clear: 25 cycles (bitnum < 8) or 30 cycles (bitnum >= 8)
;   bit set:   40 cycles (bitnum < 8) or 45 cycles (bitnum >= 8)
.macro bitmul16 arg1, arg2, result, bitnum
    .local zero
    .local one
    .local next

    ; does 16-bit adds
    ; arg1 and arg2 are treated as unsigned
    ; negative signed inputs must be flipped first

    ; 7 cycles up to the branch

    ; check if arg1 has 0 or 1 bit in this place
    ; 5 cycles either way
    .if bitnum < 8
        lda arg1                 ; 3 cyc - bit lives in the low byte
        and #(1 << bitnum)       ; 2 cyc
    .else
        lda arg1 + 1             ; 3 cyc - bit lives in the high byte
        and #(1 << (bitnum - 8)) ; 2 cyc
    .endif
    bne one ; 2 cyc

zero: ; 18 cyc, 23 cyc (from here to the end of the macro)
    ; Nothing to add: shift a 0 into the top bit; bit 0 of the top byte
    ; drops into C for the rotates at `next`.
    lsr result + 3 ; 5 cyc
    jmp next       ; 3 cyc

one: ; 32 cyc, 37 cyc (from here to the end of the macro)
    ; 16-bit add on the top bits
    clc            ; 2 cyc
    lda result + 2 ; 3 cyc
    adc arg2       ; 3 cyc
    sta result + 2 ; 3 cyc
    lda result + 3 ; 3 cyc
    adc arg2 + 1   ; 3 cyc
    ; The carry out of the add becomes the new top bit, so the 17-bit sum
    ; is not lost; bit 0 of the top byte drops into C.
    ror a          ; 2 cyc - get a jump on the shift
    sta result + 3 ; 3 cyc
next:
    ; Continue the right shift through the lower bytes, C feeding each one.
    ror result + 2 ; 5 cyc
    ror result + 1 ; 5 cyc
    .if bitnum >= 8
        ; we can save 5 cycles * 8 bits = 40 cycles total by skipping this byte
        ; when it's all uninitialized data
        ; (during steps 0..7 only never-initialized bits of result + 1 would
        ; be shifted into it, and the eight rotates of steps 8..15 overwrite
        ; every bit of the low byte anyway)
        ror result ; 5 cyc
    .endif


.endmacro

; check_sign arg
; Make the signed 16-bit value at arg non-negative, counting flips in X.
; If the sign bit is set, arg is negated in place and X is incremented;
; otherwise only A and the flags change. The caller is expected to have
; initialised X before the first use.
; 6 cycles when arg is already non-negative, 25 when it is negated
; (zero-page operand, no page-crossing branch).
.macro check_sign arg
    .local nonneg
    lda arg + 1   ; 3 cyc - the high byte carries the sign bit
    bpl nonneg    ; 2 cyc (3 when taken)
    neg16 arg     ; 18 cyc
    inx           ; 2 cyc
nonneg:
.endmacro

; imul16 dest, arg1, arg2
; Signed 16 x 16 -> 32-bit multiply: dest = arg1 * arg2.
; Stages the operands in FR0 and FR1, calls imul16_func, then copies the
; 32-bit product out of FR2. arg1 and arg2 themselves are left untouched.
; Clobbers A, X, the flags, FR0, FR1 and FR2.
; Original author's estimate: 518 - 828 cycles.
.macro imul16 dest, arg1, arg2
    copy 2, FR0, arg1   ; 12 cyc - first operand
    copy 2, FR1, arg2   ; 12 cyc - second operand
    jsr imul16_func     ; author's estimate: 470-780 cyc
    copy 4, dest, FR2   ; 24 cyc - fetch the 32-bit product
.endmacro

; imul16_func
; Signed 16 x 16 -> 32-bit multiply.
;   in:       FR0, FR1 = signed 16-bit operands (a negative operand is
;                        negated in place, so both are clobbered)
;   out:      FR2      = signed 32-bit product
;   clobbers: A, X, flags
; Original author's estimate: min 470 cycles, max 780 cycles.
.proc imul16_func
    arg1 = FR0   ; 16-bit arg (clobbered)
    arg2 = FR1   ; 16-bit arg (clobbered)
    result = FR2 ; 32-bit result

    ; Only the top 16 bits of the accumulator need clearing; the bottom
    ; two bytes are filled in by the right shifts of the multiply steps.
    lda #0          ; 2 cyc
    sta result + 2  ; 3 cyc
    sta result + 3  ; 3 cyc

    ; Work on magnitudes: X ends up holding how many operands were negative.
    ldx #0          ; 2 cyc
    check_sign arg1
    check_sign arg2

    ; Fully unrolled shift-and-add, one step per bit of arg1, trading
    ; code size for speed.
    .repeat 16, bitnum
        bitmul16 arg1, arg2, result, bitnum
    .endrepeat

    ; Exactly one negative input means the product is negative.
    cpx #1          ; 2 cyc
    bne done        ; 2 cyc
    neg32 result    ; 34 cyc
done:
    rts             ; 6 cyc
.endproc

; round16 arg
; Round the top 16 bits (bytes 2..3) of the 32-bit two's-complement
; fixed-point number at arg to nearest, in place.
;
; In two's complement the low 16 bits are always a non-negative amount added
; to the top 16 bits, whatever the sign of the whole number. So the rule is
; the same for positive and negative values: if the most significant
; discarded bit (bit 15) is set, the top half is incremented; otherwise it is
; left alone. Example: $FFFFFFFF is just below zero and rounds to $0000.
;
; Only bytes 2..3 are modified; bytes 0..1 are left as they were.
; Clobbers A and the flags.
;
; Cycles (zero-page operand, no page-crossing branch):
;   no round        - 6
;   round, no carry - 13
;   round, carry    - 17
.macro round16 arg
    .local done

    lda arg + 1  ; 3 cyc - bit 15 = most significant discarded bit
    bpl done     ; 2 cyc (3 when taken) - below one half: keep as is

    inc arg + 2  ; 5 cyc - 16-bit increment of the top half
    bne done     ; 2 cyc (3 when taken)
    inc arg + 3  ; 5 cyc - low byte wrapped, carry into the high byte

done:

.endmacro



.proc mandelbrot
    ; Iterate z = z^2 + c for one point until it escapes or the 8-bit
    ; iteration counter wraps around.
    ;
    ; input:
    ; cx: position scaled to 4.12 fixed point - -8..+7.9
    ; cy: position scaled to 4.12
    ;
    ; output:
    ; iter: iteration count at escape or 0
    ;
    ; clobbers: A, X, flags, zx through dist, and everything imul16 touches

    ; Clear every working byte from zx through iter inclusive:
    ; zx = 0
    ; zy = 0
    ; zx_2 = 0
    ; zy_2 = 0
    ; zx_zy = 0
    ; dist = 0
    ; iter = 0
    lda #0
    ldx #(iter - zx)    ; immediate: offset of the last byte to clear
initloop:
    sta zx,x
    dex
    bpl initloop        ; runs for X = 0 as well, so zx's low byte is cleared

loop:
    ; iter++ & max-iters break = 7 cyc
    inc iter       ; 5 cyc
    bne keep_going ; 2 cyc
    rts            ; counter wrapped: give up with iter = 0
keep_going:

    ; The products from the previous pass were shifted and rounded at the
    ; bottom of the loop, which leaves their 4.12 values in the top 16 bits
    ; (offset + 2) of each 32-bit slot. On the first pass they are all zero.

    ; 4.12: (-8 .. +7.9)
    ; zx = zx_2  - zy_2  + cx   = 2 * 20 = 40 cyc
    sub16 zx, zx_2 + 2, zy_2 + 2
    add16 zx, zx, cx

    ; zy = zx_zy + zx_zy + cy   = 2 * 20 = 40 cyc
    add16 zy, zx_zy + 2, zx_zy + 2
    add16 zy, zy, cy

    ; 8.24: (-128 .. +127.9)
    ; zx_2 = zx * zx
    imul16 zx_2, zx, zx

    ; zy_2 = zy * zy
    imul16 zy_2, zy, zy

    ; zx_zy = zx * zy
    imul16 zx_zy, zx, zy

    ; dist = zx_2 + zy_2        = 38 cyc
    add32 dist, zx_2, zy_2

    ; if dist >= 4 break, else continue iterating = 7 cyc
    ; The top byte of an 8.24 value is its integer part. dist is a sum of
    ; squares, so it is compared as an unsigned quantity.
    lda dist + 3  ; 3 cyc
    cmp #4        ; 2 cyc
    bcc still_in  ; 2 cyc - integer part below 4: not escaped yet
    rts
still_in:

    ; Convert each 8.24 product to 4.12 in its top 16 bits by shifting the
    ; value left four places. Only bytes 1..3 are shifted: byte 0 could only
    ; feed bits below the rounding bit, so leaving it out saves a rotate
    ; per shift without affecting the rounded result.

    ; shift and round zx_2 to 4.12
    .repeat 4          ; 60 cyc
        shl24 zx_2 + 1 ; 15 cyc
    .endrepeat
    round16 zx_2

    ; shift and round zy_2 to 4.12
    .repeat 4          ; 60 cyc
        shl24 zy_2 + 1 ; 15 cyc
    .endrepeat
    round16 zy_2

    ; shift and round zx_zy to 4.12
    .repeat 4           ; 60 cyc
        shl24 zx_zy + 1 ; 15 cyc
    .endrepeat
    round16 zx_zy

    ; if may be in the lake, look for looping output with a small buffer
    ; as an optimization vs running to max iters
    jmp loop ; 3 cycles

.endproc

.proc start
    ; Debug entry point: runs mandelbrot on one fixed test point over and
    ; over, so the result can be inspected from a debugger.

    ; Test point in 4.12 fixed point (value * 4096), stored low byte first.
    cx_test = $f800     ; -0.5  (-2048 / 4096)
    cy_test = $1000     ; +1.0  ( 4096 / 4096)

looplong:
    ; cx = -0.5
    lda #<cx_test
    sta cx
    lda #>cx_test
    sta cx + 1

    ; cy = 1
    lda #<cy_test
    sta cy
    lda #>cy_test
    sta cy + 1

    jsr mandelbrot

    ; save the completed iter count for debugging
    lda iter
    sta temp

loop:
    ; keep looping over so we can work in the debugger
    jmp looplong
.endproc