; Our zero-page vars
;
; Multi-byte values are little-endian. "fixed6.26" is a signed 32-bit
; number with 26 fraction bits, so byte +3 holds the sign and the top
; integer bits (mandelbrot's quick_exit tests exactly that byte).
ox              = $80 ; fixed6.26: center point x
oy              = $84 ; fixed6.26: center point y
cx              = $88 ; fixed6.26: c_x
cy              = $8c ; fixed6.26: c_y

zx              = $90 ; fixed6.26: z_x
zy              = $94 ; fixed6.26: z_y
zx_2            = $98 ; fixed6.26: z_x^2
zy_2            = $9c ; fixed6.26: z_y^2

zx_zy           = $a0 ; fixed6.26: z_x * z_y
dist            = $a4 ; fixed6.26: z_x^2 + z_y^2
sx              = $a8 ; i16: screen pixel x relative to center (-80 .. 79)
sy              = $aa ; i16: screen pixel y relative to center (-92 .. 91)
z_buffer_active = $ac ; boolean: 1 if the previous point was in the lake (enables cycle detection), 0 if not
z_buffer_start  = $ad ; u8: byte index of the oldest entry in z_buffer
z_buffer_end    = $ae ; u8: byte index of the next free slot in z_buffer (start == end: empty)
iter            = $af ; u8: iteration count (0 = never escaped)

ptr             = $b0 ; u16: imul8 XE table pointer; low byte must stay 0
pixel_ptr       = $b2 ; u16: framebuffer row address used by pset
zoom            = $b4 ; u8: zoom shift level (0 .. 7)
fill_level      = $b5 ; u8: current progressive-fill pass (index into fill_masks)
pixel_color     = $b6 ; u8: pset scratch - color bits to write
pixel_mask      = $b7 ; u8: pset scratch - bits of the existing byte to keep
pixel_shift     = $b8 ; u8: pset scratch - pixel position within the byte
pixel_offset    = $b9 ; u8: pset scratch - byte offset within the row
palette_offset  = $ba ; u8: luma step, 0 .. palette_entries - 1
chroma_offset   = $bb ; u8: hue step, 0 .. palette_chroma_entries - 1
palette_ticks   = $bc ; u8: frames since the last luma step
chroma_ticks    = $bd ; u8: frames since the last hue step
count_frames    = $be ; u8: frames since the last speed update (incremented by vblank_handler)
count_pixels    = $bf ; u8: points computed since the last speed update

total_pixels    = $c0 ; float48
total_ms        = $c6 ; float48
temp            = $cc ; u16 (keycheck uses temp..temp+3 as one 32-bit value, overlapping temp2)
temp2           = $ce ; u16

; Frame counts between palette steps. These are assembler constants, not
; zero-page variables - the running counters are palette_ticks/chroma_ticks.
palette_delay = 23
chroma_delay = 137

; --- OS floating-point package: zero-page registers -------------------
FR0    = $d4 ; float48: primary FP register (results land here)
FRE    = $da ; float48-sized slot between FR0 and FR1
FR1    = $e0 ; float48: second operand
FR2    = $e6 ; float48
CIX    = $f2 ; u8 - index into INBUFF
INBUFF = $f3 ; u16 - pointer to ascii
FLPTR  = $fc ; u16 - pointer to user buffer float48

; --- keyboard (OS) -----------------------------------------------------
CH1    = $02f2 ; previous character read from keyboard
CH     = $02fc ; current character read from keyboard ($ff = none)

LBUFF  = $0580 ; result buffer for FASC routine

; --- OS floating-point package: ROM entry points -----------------------
FASC   = $d8e6 ; float -> ASCII (text at INBUFF, last char has bit 7 set)
IFP    = $d9aa ; u16 -> float   (FR0:u16 -> FR0:float48)
FSUB   = $da60 ; FR0 -= FR1
FADD   = $da66 ; FR0 += FR1
FMUL   = $dadb ; FR0 *= FR1
FDIV   = $db28 ; FR0 /= FR1
ZF1    = $da46 ; zero the float48 at zero-page address X
FLD0R  = $dd89 ; FR0 = float48 at address YYXX
FLD1R  = $dd98 ; FR1 = float48 at address YYXX
FST0R  = $dda7 ; store FR0 to the float48 at address YYXX
FMOVE  = $ddb6 ; FR1 = FR0

; --- high memory layout -------------------------------------------------
framebuffer_top    = $a000 ; screen rows -92 .. -1
textbuffer         = $af00 ; one 40-column text row (status bar)
framebuffer_bottom = $b000 ; screen rows 0 .. 91
display_list       = $bf00
framebuffer_end    = $c000

; --- graphics geometry (160 pixels, 2 bits per pixel) ------------------
height      = 184
half_height = height >> 1
width       = 160
half_width  = width >> 1
stride      = width >> 2   ; bytes per row: 4 pixels per byte

; --- XL/XE memory banking ----------------------------------------------
EXTENDED_RAM = $4000 ; 16KiB bank window on the XE
PORTB        = $d301 ; memory & bank-switch for XL/XE

; --- ANTIC ---------------------------------------------------------------
DMACTL = $d400
DLISTL = $d402
DLISTH = $d403
WSYNC  = $d40a

; --- OS shadow registers ------------------------------------------------
SDLSTL = $0230
SDLSTH = $0231

; --- vertical-blank interrupt vectors ----------------------------------
SETVBV = $e45c
SYSVBV = $e45f
XITVBV = $e462

; --- color shadow registers ---------------------------------------------
COLOR0 = $02c4
COLOR1 = $02c5
COLOR2 = $02c6
COLOR3 = $02c7
COLOR4 = $02c8

; --- keycodes as read from CH ($80 is the Control bit) ------------------
KEY_PLUS  = $06
KEY_MINUS = $0e
KEY_UP    = $8e ; Control + KEY_MINUS
KEY_DOWN  = $8f
KEY_LEFT  = $86
KEY_RIGHT = $87
KEY_1     = $1f
KEY_2     = $1e
KEY_3     = $1a
KEY_4     = $18
KEY_5     = $1d
KEY_6     = $1b
KEY_7     = $33
KEY_8     = $35
KEY_9     = $30
KEY_0     = $32

; Layout of an OS floating-point number (6 bytes): an exponent/sign byte
; followed by five BCD mantissa bytes (see ms_per_frame for an example).
.struct float48
    exponent .byte   ; sign in bit 7; base-100 exponent biased by 64
    mantissa .byte 5 ; 10 BCD digits, two per byte
.endstruct

.import mul_lobyte256
.import mul_hibyte256
.import mul_hibyte512
.import sqr_lobyte
.import sqr_hibyte

.data

; Status-bar strings. They are drawn with draw_text, which translates
; each byte through char_map; lengths come from the *_end labels, so no
; terminator is stored.
strings:
str_self:
    .byte "MANDEL-6502"
str_self_end:
str_speed:
    .byte " ms/px"
str_speed_end:
str_run:
    .byte " RUN"
str_run_end:
str_done:
    .byte "DONE"
str_done_end:

str_self_len = str_self_end - str_self
str_speed_len = str_speed_end - str_speed
str_run_len = str_run_end - str_run
str_done_len = str_done_end - str_done
speed_precision = 6   ; characters of the FASC result shown in the readout

; Right-hand side of the 40-column status row:
;   number (speed_precision chars), " ms/px", one spare column, then the
;   4-character RUN/DONE marker ending at column 39.
speed_start = 40 - str_done_len - str_speed_len - speed_precision - 1
speed_len = 14 + str_speed_len ; NOTE(review): not referenced in the visible source - confirm it is still needed


char_map:
    ; ATASCII -> screen (internal) character code for codes 0 .. 127.
    ; Text is written straight into screen memory, which uses the internal
    ; ordering rather than ATASCII:
    ;   ATASCII $00-$1f -> internal $40-$5f
    ;   ATASCII $20-$5f -> internal $00-$3f
    ;   ATASCII $60-$7f -> internal $60-$7f (unchanged)
    .repeat 128, code
        .if code < 32
            .byte code + 64
        .elseif code < 96
            .byte code - 32
        .else
            .byte code
        .endif
    .endrepeat

hex_chars:
    ; NOTE(review): not referenced in the visible source - presumably kept for debug output
    .byte "0123456789abcdef"

aspect:
    ; aspect ratio!
    ; pixels at 320w are 5:6 (narrow)
    ; pixels at 160w are 5:3 (wide)
    ;
    ; cy = (sy << (8 - zoom)) * (96 / 128 = 3 / 4)
    ; cx = (sx << (8 - zoom)) * ((3 / 4) * (5 / 3) = 5 / 4)
    ;
    ; so at zoom 0 the vertical range -92 .. 91.9 is -2.15625 .. 2.15624
    ; and the horizontal range -80 .. 79.9 is -3.125 .. 3.124
    ;
    ; 184h is the equiv of 220.8h at square pixels
    ; 320 / 220.8 = 1.45 display aspect ratio
aspect_x: ; fixed3.13 5/4
    .word 5 << (13 - 2)

aspect_y: ; fixed3.13 3/4
    .word 3 << (13 - 2)

ms_per_frame: ; float48 16.66666667 (milliseconds per 60 Hz frame)
    .byte 64  ; exponent/sign
    .byte $16 ; BCD digits
    .byte $66
    .byte $66
    .byte $66
    .byte $67

; ANTIC display list template; start copies it to display_list ($bf00).
display_list_start:
    ; 24 lines overscan
    .repeat 3
        .byte $70 ; 8 blank lines
    .endrep

    ; 8 scan lines, 1 row of 40-column text (mode 2 + load memory scan)
    .byte $42
    .addr textbuffer

    ; 184 lines graphics
    ; ANTIC mode e (160px 2bpp, 1 scan line per line)
    ; Screen memory is split into two halves of 92 rows * 40 bytes, each
    ; with its own load-memory-scan ($4e) so neither half crosses a 4 KiB
    ; boundary ($a000.. and $b000..).
    .byte $4e
    .addr framebuffer_top
    .repeat half_height - 1
        .byte $0e
    .endrep
    .byte $4e
    .addr framebuffer_bottom
    .repeat half_height - 1
        .byte $0e
    .endrep

    .byte $41 ; jump and wait for vertical blank
    .addr display_list
display_list_end:
display_list_len = display_list_end - display_list_start

color_map:
    ; Iteration count -> screen byte with all four 2bpp pixels set to one
    ; color. Entry 0 (the lake) is color 0; entries 1, 2, 3, 4, ... cycle
    ; through colors 1, 2, 3, 1, ... ($55, $aa, $ff). 256 entries total.
    .byte 0
    .repeat 255, n
        .byte ((n .mod 3) + 1) * $55
    .endrepeat


; Luma values for COLOR2, COLOR1, COLOR0. update_palette reads entries
; palette_offset + 0, 1 and 2, so the first two values are repeated after
; the list to let that window wrap without a modulo.
palette_start:
    .byte $0e
    .byte $08
    .byte $04
palette_repeat:
    .byte $0e
    .byte $08

palette_entries = 3

; Hue values $10 .. $f0 (upper nibble), stepped by vblank_handler.
; NOTE(review): update_palette currently uses one hue for all three colors
; (its inx steps are commented out), so the two wrap-around entries after
; the first 15 are only needed if that is re-enabled.
palette_chroma:
    .repeat 15, i
        .byte (i + 1) << 4
    .endrepeat
    .repeat 2, i
        .byte (i + 1) << 4
    .endrepeat
palette_chroma_entries = 15

.code

; Cycle detection for lake points (see mandelbrot): a ring buffer of the
; last z_buffer_len values of (top word of zx, top word of zy), 4 bytes each.
; The timings below were measured with different buffer sizes.
;z_buffer_len = 16 ; 10.863 ms/px
;z_buffer_len = 12 ; 10.619 ms/px
z_buffer_len = 8 ; 10.612 ms/px
;z_buffer_len = 4 ; 12.395 ms/px
z_buffer_mask = z_buffer_len - 1 ; NOTE(review): not referenced in the visible source
z_buffer:
    ; the last N zx/zy values
    ; (this sits in the code segment but is written at run time, so the
    ; program must be running from RAM)
    .repeat z_buffer_len
        .word 0
        .word 0
    .endrepeat

.export start

; Progressive fill. Pass N (fill_level) computes only the points with
; (sx & fill_masks[N]) == 0 and (sy & fill_masks[N]) == 0 that an earlier
; pass has not already done, and pset draws each one as a block
; fill_masks[N] + 1 rows tall.
;max_fill_level = 6
max_fill_level = 3
fill_masks:
;    .byte %00011111
;    .byte %00001111
;    .byte %00000111
    .byte %00000011
    .byte %00000001
    .byte %00000000

; Per fill level: which 2bpp pixels of a byte pset writes
; (4, 2 or 1 pixels, starting from the leftmost/high bits).
pixel_masks:
    .byte %11111111
    .byte %11110000
    .byte %11000000

; Preset viewports selected with keys 1 .. 6 (load_viewport).
viewport_zoom:
    .byte 0
    .byte 5
    .byte 7
    .byte 5
    .byte 7
    .byte 7

; Center points as fixed6.26. NOTE(review): the masked/shifted entries
; appear to be fixed8.24 constants converted to 6.26 (drop the top two
; bits, shift left 2); the last entry is written directly in 6.26.
viewport_ox:
    .dword ($00000000 & $3fffffff) << 2
    .dword ($ff110000 & $3fffffff) << 2
    .dword ($ff110000 & $3fffffff) << 2
    .dword ($fe400000 & $3fffffff) << 2
    .dword ($fe3b0000 & $3fffffff) << 2
    .dword $fd220000

viewport_oy:
    .dword ($00000000 & $3fffffff) << 2
    .dword ($ffb60000 & $3fffffff) << 2
    .dword ($ffbe0000 & $3fffffff) << 2
    .dword ($00000000 & $3fffffff) << 2
    .dword ($fffe0000 & $3fffffff) << 2
    .dword $ff000000

; Multi-byte little-endian arithmetic helpers. All of them clobber A.
; Cycle counts assume zero-page operands; an absolute operand costs one
; more cycle per access.

; dest = arg1 + arg2, `bytes` wide (carry in cleared, carry out left in C)
; 2 + 9 * byte cycles
.macro add bytes, dest, arg1, arg2
    clc ; 2 cyc
    .repeat bytes, byte ; 9 * byte cycles
        lda arg1 + byte
        adc arg2 + byte
        sta dest + byte
    .endrepeat
.endmacro

; 20 cycles
.macro add16 dest, arg1, arg2
    add 2, dest, arg1, arg2
.endmacro

; 38 cycles
.macro add32 dest, arg1, arg2
    add 4, dest, arg1, arg2
.endmacro

; dest += C (single byte; propagates a carry from a preceding add)
; 8 cycles
.macro add_carry dest
    lda dest ; 3 cyc
    adc #0   ; 2 cyc
    sta dest ; 3 cyc
.endmacro

; dest = arg1 - arg2, `bytes` wide (borrow out: C clear)
; 2 + 9 * byte cycles
.macro sub bytes, dest, arg1, arg2
    sec ; 2 cyc
    .repeat bytes, byte ; 9 * byte cycles
        lda arg1 + byte
        sbc arg2 + byte
        sta dest + byte
    .endrepeat
.endmacro

; 20 cycles
.macro sub16 dest, arg1, arg2
    sub 2, dest, arg1, arg2
.endmacro

; 38 cycles
.macro sub32 dest, arg1, arg2
    sub 4, dest, arg1, arg2
.endmacro

; arg <<= 1 in place, `bytes` wide (the bit shifted out ends up in C)
; 5 * bytes cycles (ASL and ROL on zero page are 5 cycles each)
.macro shl bytes, arg
    asl arg              ; 5 cyc
    .repeat bytes-1, i
        rol arg + 1 + i  ; 5 cyc
    .endrepeat
.endmacro

; 10 cycles
.macro shl16 arg
    shl 2, arg
.endmacro

; 15 cycles
.macro shl24 arg
    shl 3, arg
.endmacro

; 20 cycles
.macro shl32 arg
    shl 4, arg
.endmacro

; dest = arg, `bytes` wide
; 6 * bytes cycles
.macro copy bytes, dest, arg
    .repeat bytes, byte ; 6 * bytes cycles
        lda arg + byte  ; 3 cyc
        sta dest + byte ; 3 cyc
    .endrepeat
.endmacro

; 12 cycles
.macro copy16 dest, arg
    copy 2, dest, arg
.endmacro

; 24 cycles
.macro copy32 dest, arg
    copy 4, dest, arg
.endmacro

; 36 cycles
.macro copyfloat dest, arg
    copy 6, dest, arg
.endmacro

; arg = -arg in place (two's complement), `bytes` wide
; 2 + 8 * byte cycles
.macro neg bytes, arg
    sec ; 2 cyc
    .repeat bytes, byte ; 8 * byte cycles
        lda #00         ; 2 cyc
        sbc arg + byte  ; 3 cyc
        sta arg + byte  ; 3 cyc
    .endrepeat
.endmacro

; 18 cycles
.macro neg16 arg
    neg 2, arg
.endmacro

; 34 cycles
.macro neg32 arg
    neg 4, arg
.endmacro

; Shift a 32-bit value left by `shift` bits in place, then round its top
; 16 bits (arg+2 .. arg+3) using round16.
; 20 * shift cycles, plus the cost of round16
.macro shift_round_16 arg, shift
    .repeat shift
        shl32 arg ; 20 cycles
    .endrepeat
    round16 arg ; see round16
.endmacro

; Signed 16 x 16 -> 32-bit multiply through imul16_func.
; Fixed-point formats add up: a.b * c.d -> (a+c).(b+d). The callers here
; pass fixed3.13 and get fixed6.26 back (fixed4.12 inputs would give
; fixed8.24).
; Both inputs are copied before the call, so dest may be arg1 or arg2.
; clobbers: A, FR0, FR1, FR2 and whatever imul16_func clobbers
.macro imul16 dest, arg1, arg2
    copy16 FR0, arg1  ; 12 cyc
    copy16 FR1, arg2  ; 12 cyc
    jsr imul16_func   ; ? cyc
    copy32 dest, FR2  ; 24 cyc
.endmacro

; Square of a signed 16-bit value through sqr16_func; same format rule
; as imul16 (fixed3.13 in, fixed6.26 out in this program).
; clobbers: A, FR0, FR2 and whatever sqr16_func clobbers
.macro sqr16 dest, arg
    copy16 FR0, arg   ; 12 cyc
    jsr sqr16_func    ; ? cyc
    copy32 dest, FR2  ; 24 cyc
.endmacro

; dest (u16) = arg (u8) squared, read from the imported sqr_lobyte /
; sqr_hibyte tables.
; input: arg as u8
; output: dest as u16
; clobbers a, x
.macro sqr8 dest, arg
    ldx arg
    lda sqr_lobyte,x
    sta dest
    lda sqr_hibyte,x
    sta dest + 1
.endmacro

.segment "TABLES"
; lookup table for top byte -> PORTB value for bank-switch
; Entry i selects bank i >> 6, using the same PORTB encoding as the
; bank_switch macro below: (bank << 2) | $e3.
.align 256
bank_switch_table:
    .repeat 256, i
        .byte ((i & $c0) >> 4) | $e3
    .endrepeat

.code

; Map extended-RAM bank `bank` (0 .. 3) in at EXTENDED_RAM. Clobbers A.
; NOTE(review): $e3 presumably keeps the OS ROM on, BASIC off and CPU
; access to extended RAM enabled - confirm against the XE PORTB bit map.
.macro bank_switch bank
    lda #((bank << 2) | $e3)
    sta PORTB
.endmacro

; Unsigned 8 x 8 -> 16-bit multiply: dest/dest+1 = arg1 * arg2.
; `xe` selects the implementation at assembly time.
.macro imul8 dest, arg1, arg2, xe
    .if xe
        ; using 64KB lookup table (built by imul8xe_init in the four
        ; extended-RAM banks; only even arg1 values are stored)
        ; 51-70 cycles
        ; clobbers a, x, y, flags, dest, ptr + 1, PORTB
        ; requires: low byte of ptr == 0 (only ptr + 1 is written here)
        .scope
            output = dest

            ; top 2 bits are the table bank selector
            ldx arg2                ; 3 cyc
            lda bank_switch_table,x ; 4 cyc
            sta PORTB               ; 4 cyc

            ; bottom 6 bits of arg2 pick the page within the bank
            ; add $4000 for the bank pointer
            txa          ; 2 cyc
            and #$3f     ; 2 cyc
            ora #$40     ; 2 cyc
            sta ptr + 1  ; 3 cyc

            ; copy the entry into output
            ; (arg1 with its low bit cleared is the byte offset in the page)
            lda arg1     ; 3 cyc
            and #$fe     ; 2 cyc
            tay          ; 2 cyc
            lda (ptr),y  ; 5 cyc
            sta output   ; 3 cyc
            iny          ; 2 cyc
            lda (ptr),y  ; 5 cyc
            sta output+1 ; 3 cyc

            ; note: we are not restoring memory to save 6 cycles!
            ; this means those 16kb have to be switched back to base RAM
            ; if we need to use them anywhere else
            ;;; restore memory
            ;;lda #$81     ; 2 cyc - disabled
            ;;sta PORTB    ; 4 cyc - disabled

            ; check that 1 bit we skipped to fit into space
            lda arg1     ; 3 cyc
            and #1       ; 2 cyc
            beq done     ; 2 cyc

            ; add arg2 one last time for the skipped bit (X still = arg2)
            clc          ; 2 cyc
            txa          ; 2 cyc
            adc output   ; 3 cyc
            sta output   ; 3 cyc
            lda #0       ; 2 cyc
            adc output+1 ; 3 cyc
            sta output+1 ; 3 cyc

        done:
        .endscope
    .else
        ; Using base 48k RAM compatibility mode
        ; Small table of half squares:
        ;   a * x = (a + x)^2/2 - a^2/2 - x^2/2
        ; Adapted from https://everything2.com/title/Fast+6502+multiplication
        ; 81-92 cycles
        ; clobbers a, x, flags, dest
        .scope
            mul_factor_a   = arg1
            mul_factor_x   = arg2
            mul_product_lo = dest
            mul_product_hi = dest + 1

            lda mul_factor_a      ; 3 cyc

            ; (a + x)^2/2
            clc                   ; 2 cyc
            adc mul_factor_x      ; 3 cyc
            tax                   ; 2 cyc
            bcc under256          ; 2 cyc
            lda mul_hibyte512,x   ; 4 cyc
            bcs next              ; 2 cyc
        under256:
            lda mul_hibyte256,x   ; 4 cyc
            sec                   ; 2 cyc
        next:
            ; C is set on both paths, ready for the subtraction below
            sta mul_product_hi    ; 3 cyc
            lda mul_lobyte256,x   ; 4 cyc

            ; - a^2/2
            ldx mul_factor_a      ; 3 cyc
            sbc mul_lobyte256,x   ; 4 cyc
            sta mul_product_lo    ; 3 cyc
            lda mul_product_hi    ; 3 cyc
            sbc mul_hibyte256,x   ; 4 cyc
            sta mul_product_hi    ; 3 cyc

            ; + x & a & 1:
            ; (this is a kludge to correct a
            ; roundoff error that makes odd * odd too low)
            ldx mul_factor_x      ; 3 cyc
            txa                   ; 2 cyc
            and mul_factor_a      ; 3 cyc
            and #1                ; 2 cyc

            clc                   ; 2 cyc
            adc mul_product_lo    ; 3 cyc
            bcc small_product     ; 2 cyc
            inc mul_product_hi    ; 5 cyc

            ; - x^2/2
        small_product:
            sec                   ; 2 cyc
            sbc mul_lobyte256,x   ; 4 cyc
            sta mul_product_lo    ; 3 cyc
            lda mul_product_hi    ; 3 cyc
            sbc mul_hibyte256,x   ; 4 cyc
            sta mul_product_hi    ; 3 cyc
        .endscope
    .endif
.endmacro

.proc imul8xe_init
    ; Detect banked extended RAM at EXTENDED_RAM and, if present, switch the
    ; 16-bit multiply routines to their table-driven versions and build
    ; the 64 KiB imul8 table across banks 0 .. 3.
    ;
    ; Probe: store 0 at $4000 with bank 0 selected and 1 with bank 1,
    ; then reread with bank 0. Reading 0 back means the banks are
    ; separate memory. Without banking, both stores hit the same base-RAM
    ; byte (which is simply overwritten) and we return with the base-RAM
    ; routines untouched.
    ;
    ; When banking is found, bank 3 is left switched in on return.
    ; clobbers: A, Y, flags, FR0 - FR2, ptr, temp2, $4000 - $7fff

    arg1 = FR1
    arg2 = FR2

    bank_switch 0
    lda #0
    sta EXTENDED_RAM
    bank_switch 1
    lda #1
    sta EXTENDED_RAM
    bank_switch 0
    lda EXTENDED_RAM
    bne done                ; the banks alias each other: no extended RAM

    ; Turn imul16_func and sqr16_func into `jmp` thunks to their XE
    ; versions (self-modifying: this code must run from RAM).
    lda #$4c                ; 'jmp' opcode
    sta imul16_func
    sta sqr16_func
    lda #.lobyte(imul16xe_func)
    sta imul16_func + 1
    lda #.hibyte(imul16xe_func)
    sta imul16_func + 2
    lda #.lobyte(sqr16xe_func)
    sta sqr16_func + 1
    lda #.hibyte(sqr16xe_func)
    sta sqr16_func + 2

    ; Table build state:
    ;   arg2 = first multiplier of the chunk being built. Its top two bits
    ;          are the bank number, exactly as imul8 selects banks.
    ;   ptr  = the pointer imul8 uses later; its low byte must be 0.
    lda #$00
    sta arg1
    sta arg2
    sta ptr
    lda #$40
    sta ptr + 1

    ; One 16 KiB chunk per bank. imul8xe_init_section advances arg2 by 64,
    ; so after the chunk for bank 3 it wraps back to 0 and we stop.
    ; (Y is used for the lookup because the section routine clobbers Y anyway.)
bank_loop:
    ldy arg2
    lda bank_switch_table,y
    sta PORTB
    jsr imul8xe_init_section
    lda arg2
    bne bank_loop

done:
    rts
.endproc

; Fill one 16 KiB chunk ($4000 - $7fff of the currently selected bank) of
; the XE multiplication table used by imul8 (xe = 1):
;   page   $40 + (arg2 & $3f)  one page per multiplier value
;   offset 2 * k               u16 product (2 * k) * arg2, low byte first
; Only even multiplicands are stored; imul8 adds arg2 once more for odd
; ones.
;
; input:  arg2 (FR2) = first multiplier of this chunk; the caller must
;         already have switched the matching bank in via PORTB
; output: arg2 advanced by 64 (first multiplier of the next chunk)
; clobbers: A, flags, FR0 / FR0+1 (running product), temp2 / temp2+1
;         (row pointer). Y is 0 on entry and on exit.
.proc imul8xe_init_section
    arg2 = FR2
    result = FR0
    ptr = temp2

    ; row pointer = $4000, the start of the banked window
    lda #$00
    sta ptr
    lda #$40
    sta ptr + 1

    ldy #0                  ; the pointer walks the row; Y only toggles 0/1

    ; outer loop: one page per multiplier, pages $40 -> $7f
outer_loop:

    ; product for multiplicand 0
    lda #0
    sta result
    sta result + 1

    ; inner loop: 128 two-byte entries per page
inner_loop:

    ; store the current product
    lda result
    sta (ptr),y
    lda result + 1
    iny
    sta (ptr),y
    dey

    ; next even multiplicand: result += 2 * arg2
    clc
    lda arg2
    adc result
    sta result
    lda #0
    adc result + 1
    sta result + 1
    clc
    lda arg2
    adc result
    sta result
    lda #0
    adc result + 1
    sta result + 1

    ; inner loop check: the low byte wraps to 0 after 128 entries
    inc ptr
    inc ptr
    bne inner_loop

    ; outer loop check
    inc arg2
    inc ptr + 1
    lda ptr + 1
    cmp #$80
    bne outer_loop

    rts

.endproc

; Body of a signed 16 x 16 -> 32-bit multiply subroutine (ends in rts).
;   input:  FR0 (i16), FR1 (i16)
;   output: FR2 (i32) = FR0 * FR1
;   clobbers: A, flags, temp2, plus whatever imul8 clobbers for this `xe`
; Four unsigned 8 x 8 partial products are summed, then the high word is
; corrected for negative inputs.
.macro imul16_impl xe
    .local arg1
    .local arg2
    .local result
    .local inter
    .local arg1_pos
    .local arg2_pos
    arg1 = FR0   ; 16-bit arg (read only here; callers treat it as scratch)
    arg2 = FR1   ; 16-bit arg (read only here; callers treat it as scratch)
    result = FR2 ; 32-bit result
    inter = temp2

    ; h1l1 * h2l2
    ; (h1*256 + l1) * (h2*256 + l2)
    ; h1*256*(h2*256 + l2) + l1*(h2*256 + l2)
    ; h1*h2*256*256 + h1*l2*256 + h2*l1*256 + l1*l2

    ; l1*l2 -> bytes 0-1
    imul8 result, arg1, arg2, xe

    ; h1*h2 -> bytes 2-3
    imul8 result + 2, arg1 + 1, arg2 + 1, xe

    ; + h1*l2 at bytes 1-2, carry into byte 3
    imul8 inter, arg1 + 1, arg2, xe
    add16 result + 1, result + 1, inter
    add_carry result + 3

    ; + l1*h2 at bytes 1-2, carry into byte 3
    imul8 inter, arg1, arg2 + 1, xe
    add16 result + 1, result + 1, inter
    add_carry result + 3

    ; In case of negative inputs, adjust high word
    ; (the unsigned product counts a negative input as input + 65536)
    ; https://stackoverflow.com/a/28827013
    lda arg1 + 1
    bpl arg1_pos
    sub16 result + 2, result + 2, arg2
arg1_pos:
    lda arg2 + 1
    bpl arg2_pos
    sub16 result + 2, result + 2, arg1
arg2_pos:

    rts ; 6 cyc
.endmacro

; Body of a 16-bit squaring subroutine (ends in rts).
;   input:  FR0 (i16)
;   output: FR2 (u32) = FR0 * FR0
;   clobbers: A, X, flags, FR0 (negated in place if negative), FR1, plus
;             whatever imul8 clobbers for this `xe`
; Squaring |FR0| as unsigned avoids the sign fix-up needed by imul16_impl.
.macro sqr16_impl xe
    .scope
        arg = FR0    ; 16-bit arg (clobbered)
        result = FR2 ; 32-bit result
        ;inter = temp2
        inter = FR1

        ; arg = |arg|
        lda arg + 1
        bpl arg_pos
        neg16 arg
    arg_pos:

        ; hl * hl
        ; (h*256 + l) * (h*256 + l)
        ; h*256*(h*256 + l) + l*(h*256 + l)
        ; h*h*256*256 + h*l*256 + h*l*256 + l*l

        ; l*l -> bytes 0-1
        sqr8 result, arg

        ; h*h -> bytes 2-3
        sqr8 result + 2, arg + 1

        ; + 2 * h*l at bytes 1-2, carries into byte 3
        imul8 inter, arg + 1, arg, xe
        add16 result + 1, result + 1, inter
        add_carry result + 3
        add16 result + 1, result + 1, inter
        add_carry result + 3

        rts ; 6 cyc
    .endscope
.endmacro

; Base-RAM signed 16 x 16 multiply (FR0 * FR1 -> FR2). imul8xe_init
; overwrites its first three bytes with `jmp imul16xe_func` when
; extended RAM is available.
.proc imul16_func
    imul16_impl 0
.endproc

; Same multiply using the extended-RAM 8 x 8 table.
.proc imul16xe_func
    imul16_impl 1
.endproc

; Base-RAM 16-bit square (FR0^2 -> FR2). imul8xe_init overwrites its
; first three bytes with `jmp sqr16xe_func` when extended RAM is available.
.proc sqr16_func
    sqr16_impl 0
.endproc

; Same square using the extended-RAM 8 x 8 table.
.proc sqr16xe_func
    sqr16_impl 1
.endproc

; Round the top 16 bits (arg+2 .. arg+3) of a 32-bit fixed-point value in
; place, according to the low 16 bits (arg+0 .. arg+1):
;   low word >  $8000: round up
;   low word == $8000: round up if the value is non-negative,
;                      leave it if negative (ties away from zero)
;   low word <  $8000: leave the top word alone
; Clobbers A and flags. About 12-30 cycles depending on the path.
.macro round16 arg
    .local round_up
    .local exactly_half
    .local finished

    lda arg + 1
    cmp #$80
    beq exactly_half        ; low word is $80xx
    bcs round_up            ; $81xx .. $ffxx
    bcc finished            ; $00xx .. $7fxx (always taken)

exactly_half:
    lda arg
    bne round_up            ; $8001 .. $80ff
    lda arg + 3             ; exactly $8000: the sign decides
    bmi finished

round_up:
    inc arg + 2
    bne finished
    inc arg + 3

finished:

.endmacro

.proc mandelbrot
    ; Iterate z = z^2 + c for one point.
    ;
    ; input:
    ;   cx, cy: the point c, fixed6.26 (-32 .. +31.9)
    ;   z_buffer_active: nonzero enables cycle detection; exit_path sets it
    ;     for the next call when this point ends up in the lake
    ;
    ; output:
    ;   iter: iteration count at escape, or 0 if the point did not escape
    ;     within 255 iterations or was caught in a repeating cycle
    ;   z_buffer_active: 1 if iter == 0, else 0
    ;
    ; z is kept in fixed6.26 for the additions and rounded to fixed3.13
    ; for the multiplies, whose products come back as fixed6.26.
    ;
    ; clobbers: A, X, Y, flags, zx .. dist, z_buffer_start, z_buffer_end,
    ;   z_buffer, plus whatever the sqr16 / imul16 helpers clobber

    ; Reset the iteration state. This is unrolled rather than a loop over
    ; zx .. iter because sx, sy and z_buffer_active live inside that
    ; address range and must survive.
    lda #00
    sta zx
    sta zx + 1
    sta zx + 2
    sta zx + 3
    sta zy
    sta zy + 1
    sta zy + 2
    sta zy + 3
    sta zx_2
    sta zx_2 + 1
    sta zx_2 + 2
    sta zx_2 + 3
    sta zy_2
    sta zy_2 + 1
    sta zy_2 + 2
    sta zy_2 + 3
    sta zx_zy
    sta zx_zy + 1
    sta zx_zy + 2
    sta zx_zy + 3
    sta dist
    sta dist + 1
    sta dist + 2
    sta dist + 3
    sta iter
    sta z_buffer_start
    sta z_buffer_end

loop:
    ; iter++ & max-iters break
    ; (wrapping to 0 after 255 iterations reports the point as in the lake)
    inc iter
    bne keep_going
    jmp exit_path
keep_going:

    .macro quick_exit arg, max
        ; Leave through exit_path if arg >= max or arg <= -max.
        ; arg: fixed6.26 (only the top byte is needed unless it is exactly -max)
        ; max: small positive integer (max << 2 must be below 128)
        ; The signed bmi/bpl after each cmp are safe because the top byte is
        ; first split by sign, which keeps every difference in -128 .. 127.
        .local positive
        .local negative
        .local nope_out
        .local first_equal
        .local all_done

        ; check sign bit
        lda arg + 3
        bmi negative

    positive:
        cmp #(max << 2)
        bmi all_done ; 'less than'
        jmp exit_path

    negative:
        cmp #(256 - (max << 2))
        beq first_equal ; 'equal' on first byte
        bpl all_done    ; 'greater than'

    nope_out:
        jmp exit_path

    first_equal:
        ; following bytes all 0 shows it's really 'equal'
        lda arg + 2
        bne all_done
        lda arg + 1
        bne all_done
        lda arg
        bne all_done
        jmp exit_path

    all_done:
    .endmacro

    ; 6.26: (-32 .. 31.9)
    ; zx = zx_2  - zy_2  + cx
    sub32 zx, zx_2, zy_2
    add32 zx, zx, cx
    quick_exit zx, 2

    ; zy = zx_zy + zx_zy + cy
    add32 zy, zx_zy, zx_zy
    add32 zy, zy, cy
    quick_exit zy, 2

    ; convert 6.26 -> 3.13 in the top words: (-4 .. +3.9)
    shift_round_16 zx, 3
    shift_round_16 zy, 3

    ; zx_2 = zx * zx
    sqr16 zx_2, zx + 2

    ; zy_2 = zy * zy
    sqr16 zy_2, zy + 2

    ; zx_zy = zx * zy
    imul16 zx_zy, zx + 2, zy + 2

    ; dist = zx_2 + zy_2
    add32 dist, zx_2, zy_2
    quick_exit dist, 4

    ; If the previous point was in the lake this one probably is too, so
    ; look for a repeating z instead of running to max iters. The next
    ; iteration depends only on the rounded top words of zx and zy (plus
    ; the constant c), so an exact match of those 4 bytes with an earlier
    ; iteration proves a cycle that can never escape.
    ;
    ; Ring buffer: X walks byte indices from z_buffer_start (oldest entry)
    ; to z_buffer_end (next free slot); start == end means empty.
    lda z_buffer_active
    beq skip_z_buffer

    ldx z_buffer_start
    cpx z_buffer_end
    beq z_nothing_to_read

z_buffer_loop:
    .macro z_compare arg
        ; compare one stored byte with arg; Y counts matches
        .local compare_no_match
        lda z_buffer,x
        inx
        cmp arg
        bne compare_no_match
        iny
    compare_no_match:
    .endmacro
    .macro z_advance
        ; wrap the byte index X back to 0 at the end of the ring buffer
        .local skip_reset_x
        cpx #(z_buffer_len * 4)
        bcc skip_reset_x
        ldx #0
    skip_reset_x:
    .endmacro
    .macro z_store arg
        lda arg
        sta z_buffer,x
        inx
    .endmacro

    ; Compare the previously stored z values
    ldy #0
    z_compare zx + 2
    z_compare zx + 3
    z_compare zy + 2
    z_compare zy + 3

    cpy #4
    bne z_no_matches
    jmp z_exit

z_no_matches:
    z_advance

    cpx z_buffer_end
    bne z_buffer_loop

z_nothing_to_read:

    ; Store and expand
    z_store zx + 2
    z_store zx + 3
    z_store zy + 2
    z_store zy + 3
    z_advance
    stx z_buffer_end

    ; Exactly one entry is stored per iteration, so iter is the number of
    ; entries stored so far. After z_buffer_len stores the write index has
    ; wrapped onto z_buffer_start; from then on drop the oldest entry with
    ; each store so the buffer holds the latest z_buffer_len - 1 entries
    ; instead of appearing empty. (Unsigned compare: iter goes up to 255.)
    lda iter
    cmp #z_buffer_len
    bcc skip_inc_start
    lda z_buffer_start
    clc
    adc #4
    tax
    z_advance
    stx z_buffer_start
skip_inc_start:

skip_z_buffer:

    jmp loop

z_exit:
    ; cycle found: report the point as in the lake
    lda #0
    sta iter

exit_path:
    ; enable cycle detection for the next point iff this one is in the lake
    ldx #0
    lda iter
    bne next
    inx
next:
    stx z_buffer_active
    rts

.endproc

.macro scale_zoom dest
    ; dest (16-bit, in place) <<= (8 - zoom)
    ; Expects zoom <= 8. Clobbers X and flags; A and Y are untouched.
    .local check_count
    .local finished

    ldx zoom
check_count:
    cpx #8
    beq finished
    shl16 dest
    inx
    jmp check_count
finished:
.endmacro

.macro zoom_factor dest, src, aspect
    ; dest (fixed6.26) = (src << (8 - zoom)) * aspect
    ;   src:    i16 screen offset (sx or sy); once shifted it is read as
    ;           fixed3.13 (e.g. -80 at zoom 0 becomes -2.5)
    ;   aspect: fixed3.13 pixel-aspect factor (aspect_x or aspect_y)
    ; The low word of dest holds the shifted value before the multiply.
    ; clobbers A, X, flags and whatever imul16 clobbers
    copy16 dest, src
    scale_zoom dest
    imul16 dest, dest, aspect
.endmacro

.proc pset
    ; Plot point (sx, sy) in the color for `iter`, as a block whose size
    ; depends on fill_level: 4x4, 2x2 or 1x1 pixels for levels 0, 1, 2
    ; (pixel_masks gives the width, fill_masks + 1 the height).
    ;
    ; input:
    ;   sx: i16, -80 .. 79 (column relative to the screen center)
    ;   sy: i16, -92 .. 91 (row relative to the screen center)
    ;   iter: iteration count, turned into a color through color_map
    ;   fill_level: index into pixel_masks / fill_masks
    ; clobbers: A, X, Y, flags, temp, pixel_ptr, pixel_color, pixel_mask,
    ;   pixel_shift, pixel_offset

    ; The row-address code below multiplies by 40 with shifts.
    .assert stride = 40, error, "pset assumes stride = 40"

    ; iter -> color byte, trimmed to this fill level's block width.
    ; pixel_mask has ones where the existing pixels must be kept.
    ldx iter
    lda color_map,x
    ldx fill_level
    and pixel_masks,x
    sta pixel_color
    lda pixel_masks,x
    eor #$ff
    sta pixel_mask

    ; pixel_ptr = address that row 0 would have in the half of the screen
    ; containing row sy (sign taken from the high byte of the i16)
    lda sy + 1
    bpl positive

negative:
    ; rows -92 .. -1 are in the top half; its "row 0" is just past its end
    lda #.lobyte(framebuffer_top + stride * half_height)
    sta pixel_ptr
    lda #.hibyte(framebuffer_top + stride * half_height)
    sta pixel_ptr + 1
    jmp point

positive:
    ; rows 0 .. 91 are in the bottom half
    lda #.lobyte(framebuffer_bottom)
    sta pixel_ptr
    lda #.hibyte(framebuffer_bottom)
    sta pixel_ptr + 1

point:

    ; pixel_ptr += sy * stride
    ;    sy * 40
    ; =  sy * 32  +  sy * 8
    ; = (sy << 5) + (sy << 3)
    copy16 temp, sy
    shl16 temp
    shl16 temp
    shl16 temp
    add16 pixel_ptr, pixel_ptr, temp
    shl16 temp
    shl16 temp
    add16 pixel_ptr, pixel_ptr, temp

    ; pixel_ptr now points at the start of the row (stride bytes).
    ; temp = screen column 0 .. 159
    lda sx
    clc
    adc #half_width
    sta temp

    ; Four 2bpp pixels per byte, leftmost pixel in the high bits.
    ; pixel_shift = temp & 3
    ; pixel_color >>= 2 * pixel_shift (shifting in zeros)
    ; pixel_mask  >>= 2 * pixel_shift (shifting in ones)
    and #3
    sta pixel_shift
    tax
shift_loop:
    beq shift_done
    lsr pixel_color
    lsr pixel_color
    sec
    ror pixel_mask
    sec
    ror pixel_mask
    dex
    jmp shift_loop
shift_done:

    ; X = rows to draw = fill_masks[fill_level] + 1
    ldy fill_level
    ldx fill_masks,y
    inx

    ; pixel_offset = temp >> 2 (byte within the row)
    lda temp
    lsr a
    lsr a
    sta pixel_offset
    tay

draw_pixel:
    ; read, mask, or, write
    lda (pixel_ptr),y
    and pixel_mask
    ora pixel_color
    sta (pixel_ptr),y

    dex
    beq done
    ; next row
    clc
    lda #stride
    adc pixel_ptr
    sta pixel_ptr
    lda #0
    adc pixel_ptr + 1
    sta pixel_ptr + 1
    jmp draw_pixel

done:
    rts
.endproc

.macro draw_text_indirect col, len, strptr
    ; Copy `len` ATASCII characters from the string addressed by the
    ; zero-page pointer `strptr` into textbuffer at column `col`,
    ; translating each through char_map.
    ; Bit 7 of each character is dropped before the lookup: char_map only
    ; covers codes 0 .. 127, and the OS FASC routine marks the last
    ; character of its output by setting that bit, which would otherwise
    ; index past the end of the table.
    ; clobbers A, X, Y, flags
    .local loop
    .local done
    ldx #0
loop:
    cpx #len
    beq done
    txa
    tay
    lda (strptr),y
    and #$7f
    tay
    lda char_map,y
    sta textbuffer + col,x
    inx
    jmp loop
done:
.endmacro

.macro draw_text col, len, cstr
    ; Copy `len` ATASCII characters from the absolute address `cstr` into
    ; textbuffer at column `col`, translating each through char_map.
    ; The characters must be below 128 (char_map has 128 entries).
    ; clobbers A, X, Y, flags
    .local loop
    .local done
    ldx #0
loop:
    cpx #len
    beq done
    ldy cstr,x
    lda char_map,y
    sta textbuffer + col,x
    inx
    jmp loop
done:
.endmacro

.proc vblank_handler
    ; Deferred vertical-blank routine (installed with SETVBV in start).
    ;  - counts frames for the speed readout
    ;  - every chroma_delay frames, steps chroma_offset through
    ;    0 .. palette_chroma_entries - 1
    ;  - every palette_delay frames, steps palette_offset through
    ;    0 .. palette_entries - 1
    ; then refreshes the color shadow registers and leaves through XITVBV.
    inc count_frames

    ; hue
    inc chroma_ticks
    lda chroma_ticks
    cmp #chroma_delay
    bne chroma_done
    lda #0
    sta chroma_ticks
    ldx chroma_offset
    inx
    cpx #palette_chroma_entries
    bne store_chroma
    ldx #0
store_chroma:
    stx chroma_offset
chroma_done:

    ; brightness
    inc palette_ticks
    lda palette_ticks
    cmp #palette_delay
    bne luma_done
    lda #0
    sta palette_ticks
    ldx palette_offset
    inx
    cpx #palette_entries
    bne store_luma
    ldx #0
store_luma:
    stx palette_offset
luma_done:

    jsr update_palette
    jmp XITVBV
.endproc

.proc update_palette
    ; Write the current palette to the OS color shadow registers:
    ;   COLOR2, COLOR1, COLOR0 = luma palette_start[palette_offset + 0, 1, 2]
    ;                            combined with hue palette_chroma[chroma_offset]
    ;   COLOR4 (background)    = 0
    ; clobbers A, X, Y, flags
    ldx chroma_offset
    ldy palette_offset

    lda palette_start,y
    ora palette_chroma,x
    sta COLOR2

    lda palette_start + 1,y
    ora palette_chroma,x
    sta COLOR1

    lda palette_start + 2,y
    ora palette_chroma,x
    sta COLOR0

    lda #0
    sta COLOR4

    rts
.endproc

.proc update_speed
    ; Placeholder for refreshing the ms/px readout. The same work is
    ; currently done inline in start (update_status). Planned steps:
    ;   convert frames (u16) to fp
    ;   add to frames_total
    ;   convert pixels (u16) to fp
    ;   add to pixels_total
    ;   (frames_total * 16.66666667) / pixels_total
    ;   convert to ATASCII
    ;   draw text
    ;
    ; Not implemented: return immediately so a call cannot fall through
    ; into whatever code follows this procedure.
    rts
.endproc

.proc keycheck
    ; Poll the keyboard and apply at most one view change:
    ;   + / -        zoom in / out (zoom stays within 0 .. 7)
    ;   arrow keys   move the center (ox, oy) by a zoom-dependent step
    ;   1 .. 6       jump to preset viewport 0 .. 5 (load_viewport)
    ; Any other key is consumed and ignored.
    ;
    ; returns: A = 255 (Z clear) if the view changed, A = 0 (Z set) if not
    ; clobbers: A, X, Y, flags, temp, temp2 (temp is used as 32 bits)

    ; check keyboard buffer ($ff = no key)
    lda CH
    cmp #$ff
    beq skip_char

    ; mark the key as consumed so it is handled only once
    ldx #$ff
    stx CH

    tay                     ; Y = keycode for all the tests below

    cpy #KEY_PLUS
    beq plus
    cpy #KEY_MINUS
    beq minus

    ; pan step: temp..temp+3 = $00010000 << (8 - zoom), a fixed6.26 value
    ; (its upper word overlaps temp2)
    lda #$00
    sta temp
    sta temp + 1
    lda #$01
    sta temp + 2
    lda #$00
    sta temp + 3
    scale_zoom temp + 2     ; uses X only, so Y still holds the keycode

    cpy #KEY_UP
    beq up
    cpy #KEY_DOWN
    beq down
    cpy #KEY_LEFT
    beq left
    cpy #KEY_RIGHT
    beq right
    jmp number_keys

skip_char:
    lda #0
    rts

plus:
    lda zoom
    cmp #7
    bcs skip_char           ; already at the deepest zoom
    inc zoom
    jmp done
minus:
    lda zoom
    cmp #1
    bcc skip_char           ; already fully zoomed out
    dec zoom
    jmp done
up:
    sub32 oy, oy, temp
    jmp done
down:
    add32 oy, oy, temp
    jmp done
left:
    sub32 ox, ox, temp
    jmp done
right:
    add32 ox, ox, temp
    jmp done

number_keys:
    cpy #KEY_1
    beq one
    cpy #KEY_2
    beq two
    cpy #KEY_3
    beq three
    cpy #KEY_4
    beq four
    cpy #KEY_5
    beq five
    cpy #KEY_6
    beq six
    jmp skip_char

one:
    ldx #0
    jmp load_key_viewport
two:
    ldx #1
    jmp load_key_viewport
three:
    ldx #2
    jmp load_key_viewport
four:
    ldx #3
    jmp load_key_viewport
five:
    ldx #4
    jmp load_key_viewport
six:
    ldx #5
    ; fall through
load_key_viewport:
    jsr load_viewport
    ; fall through
done:
    lda #255
    rts

.endproc

.proc clear_screen
    ; Zero every page from framebuffer_top up to (not including) the page
    ; of display_list: both framebuffer halves and the text row.
    ; framebuffer_top must be page-aligned (whole pages are cleared).
    ; clobbers A, Y, flags, temp
    lda #.lobyte(framebuffer_top)
    sta temp
    lda #.hibyte(framebuffer_top)
    sta temp + 1

next_page:
    ldy #0
    tya                     ; A = 0, the fill value
fill_page:
    sta (temp),y
    iny
    bne fill_page

    ; advance a page; stop on reaching the display list's page
    inc temp + 1
    lda temp + 1
    cmp #.hibyte(display_list)
    bne next_page

    rts
.endproc

.proc status_bar
    ; Draw the fixed parts of the status row: the " RUN" marker at the
    ; right edge and the program name at column 0.
    ; clobbers A, X, Y, flags
    draw_text 40 - str_run_len, str_run_len, str_run
    draw_text 0, str_self_len, str_self

    rts
.endproc

; Load preset viewport X (0 .. 5): zoom from viewport_zoom, and the
; 32-bit centers ox / oy from entry X of viewport_ox / viewport_oy.
; input: viewport selector in x
; clobbers: a, x (leaves X = selector * 4 + 3)
.proc load_viewport

    lda viewport_zoom,x
    sta zoom

    ; X = byte offset of the 4-byte table entry
    txa
    asl a
    asl a
    tax

    ; copy both centers byte by byte, low byte first
    .repeat 4, i
        lda viewport_ox,x
        sta ox + i
        lda viewport_oy,x
        sta oy + i
        .if i < 3
            inx
        .endif
    .endrepeat

    rts
.endproc

.proc start
    ; Program entry point. Sets up the multiply tables, viewport, display
    ; list, palette and vertical-blank handler, then renders forever.
    ; Rendering is progressive: pass 0 computes every 4th point as 4x4
    ; blocks, pass 1 fills in at 2x2, pass 2 at 1x1. A key press that
    ; changes the view restarts from main_loop.

    jsr imul8xe_init

    ; initialize viewport
    ldx #0 ; overview
    jsr load_viewport

    ; Disable display DMA while the screen and display list are rebuilt
    lda #0
    sta DMACTL

    jsr clear_screen

    ; Copy the display list into properly aligned memory
    ; Can't cross 1024-byte boundaries :D
    ldx #0
copy_byte_loop:
    lda display_list_start,x
    sta display_list,x
    inx
    cpx #display_list_len
    bne copy_byte_loop

    ; Set up the display list
    lda #.lobyte(display_list)
    sta DLISTL ; actual register
    sta SDLSTL ; shadow register the OS will copy in
    lda #.hibyte(display_list)
    sta DLISTH ; actual register
    sta SDLSTH ; shadow register the OS will copy in

    ; Re-enable display DMA
    lda #$22
    sta DMACTL

    ; Initialize the palette state. The tick counters are the variables;
    ; palette_delay / chroma_delay are constants and must not be stored to.
    lda #0
    sta palette_offset
    sta palette_ticks
    sta chroma_offset
    sta chroma_ticks
    jsr update_palette

    ; install the vblank handler
    lda #7 ; deferred
    ldx #.hibyte(vblank_handler)
    ldy #.lobyte(vblank_handler)
    jsr SETVBV

main_loop:
    ; count_frames = 0; count_pixels = 0
    lda #0
    sta count_frames
    sta count_pixels

    ; total_ms = 0.0; total_pixels = 0.0
    ldx #total_ms
    jsr ZF1
    ldx #total_pixels
    jsr ZF1

    jsr clear_screen
    jsr status_bar

    lda #0
    sta fill_level

fill_loop:

    ; sy = -92 .. 91
    lda #(256-half_height)
    sta sy
    lda #(256-1)
    sta sy + 1

loop_sy:
    ; sx = -80 .. 79
    lda #(256-half_width)
    sta sx
    lda #(256-1)
    sta sx + 1

loop_sx:
    ; Skip points an earlier pass already drew: for each level y below
    ; fill_level, a point with (sx & mask) == 0 and (sy & mask) == 0 is done.
    ldy #0

loop_skip_level:
    cpy fill_level
    beq current_level

    lda fill_masks,y
    and sx
    bne not_skipped_mask1

    lda fill_masks,y
    and sy
    beq skipped_mask

not_skipped_mask1:
    iny
    jmp loop_skip_level

current_level:
    ; ...and only compute points on this pass's grid
    lda fill_masks,y
    and sx
    bne skipped_mask

    lda fill_masks,y
    and sy
    beq not_skipped_mask

skipped_mask:
    jmp skipped

not_skipped_mask:

    ; run the fractal!
    zoom_factor cx, sx, aspect_x
    add32 cx, cx, ox
    zoom_factor cy, sy, aspect_y
    add32 cy, cy, oy
    jsr mandelbrot
    jsr pset

    jsr keycheck
    beq no_key
    ; the view changed: restart (main_loop also resets the speed stats)
    jmp main_loop

no_key:
    ; check if we should update the counters
    ;
    ; count_pixels >= width? update! (unsigned compare)
    inc count_pixels
    lda count_pixels
    cmp #width
    bcs update_status

    ; count_frames >= 120? update! (unsigned compare)
    lda count_frames
    cmp #120 ; >= 2 seconds at 60 Hz
    bcs update_status
    jmp skipped ; too far for a short branch

update_status:
    ; FR0 = (float)count_pixels & clear count_pixels
    lda count_pixels
    sta FR0
    lda #0
    sta FR0 + 1
    sta count_pixels
    jsr IFP

    ; FR1 = total_pixels
    ldx #.lobyte(total_pixels)
    ldy #.hibyte(total_pixels)
    jsr FLD1R

    ; FR0 += FR1
    jsr FADD

    ; total_pixels = FR0
    ldx #.lobyte(total_pixels)
    ldy #.hibyte(total_pixels)
    jsr FST0R


    ; FR0 = (float)count_frames & clear count_frames
    ; warning: this should really disable interrupts @TODO
    lda count_frames
    sta FR0
    lda #0
    sta FR0 + 1
    sta count_frames
    jsr IFP

    ; FR0 *= ms_per_frame
    ldx #.lobyte(ms_per_frame)
    ldy #.hibyte(ms_per_frame)
    jsr FLD1R
    jsr FMUL

    ; FR0 += total_ms
    ldx #total_ms
    ldy #0
    jsr FLD1R
    jsr FADD

    ; total_ms = FR0
    ldx #total_ms
    ldy #0
    jsr FST0R

    ; FR0 /= total_pixels
    ldx #total_pixels
    ldy #0
    jsr FLD1R
    jsr FDIV

    ; convert to ASCII in INBUFF
    jsr FASC

    ; print the first 6 digits
    draw_text_indirect speed_start, speed_precision, INBUFF
    draw_text speed_start + speed_precision, str_speed_len, str_speed

skipped:

    ; sx += fill_masks[fill_level] + 1
    ldx fill_level
    lda fill_masks,x
    clc
    adc #1 ; will never carry
    adc sx
    sta sx
    lda #0
    adc sx + 1
    sta sx + 1

    lda sx
    cmp #half_width
    beq loop_sx_done
    jmp loop_sx

loop_sx_done:

    ; sy += fill_masks[fill_level] + 1
    ldx fill_level
    lda fill_masks,x
    clc
    adc #1 ; will never carry
    adc sy
    sta sy
    lda #0
    adc sy + 1
    sta sy + 1

    lda sy
    cmp #half_height
    beq loop_sy_done
    jmp loop_sy

loop_sy_done:

fill_loop_done:
    inc fill_level
    lda fill_level
    cmp #max_fill_level
    beq loop
    jmp fill_loop

loop:
    ; finished: show DONE and wait for a key that changes the view
    draw_text 40 - str_done_len, str_done_len, str_done
    jsr keycheck
    beq loop
    jmp main_loop

.endproc