; Our zero-page vars
;
; NOTE: mandelbrot clears every byte from zx through iter at the start of
; each call, so anything that must survive between pixels lives outside
; that range.
;
; NOTE(review): the comments below describe the c/z values as fixed 3.13
; (products 6.26), but the mandelbrot routine's comments and its
; `shift_round_16 ..., 4` / `cmp #4` escape test treat them as 4.12
; (products 8.24). Confirm which format is intended.
sx    = $80     ; i16: screen pixel x
sy    = $82     ; i16: screen pixel y
cx    = $84     ; fixed3.13: c_x (rewritten by zoom_factor for every pixel)
cy    = $86     ; fixed3.13: c_y (rewritten by zoom_factor for every pixel)
zx    = $88     ; fixed3.13: z_x
zy    = $8a     ; fixed3.13: z_y

zx_2  = $90     ; fixed6.26: z_x^2
zy_2  = $94     ; fixed6.26: z_y^2
zx_zy = $98     ; fixed6.26: z_x * z_y
dist  = $9c     ; fixed6.26: z_x^2 + z_y^2

iter  = $a0     ; u8: iteration count
zoom  = $a1     ; u8: zoom shift level
temp  = $a2     ; u16
temp2 = $a4     ; u16

; View center. Kept in its own bytes, apart from cx/cy (which are
; overwritten for every pixel) and outside mandelbrot's cleared range.
ox    = $a6     ; fixed3.13: center point x
oy    = $a8     ; fixed3.13: center point y

; Scratch for pset
pixel_ptr    = $b0 ; u16: start of the current framebuffer line
pixel_color  = $b2 ; u8
pixel_mask   = $b3 ; u8
pixel_shift  = $b4 ; u8
pixel_offset = $b5 ; u8

; OS floating-point registers in zero page, reused here as argument and
; result space for imul16_func
FR0 = $d4
FRE = $da
FR1 = $e0
FR2 = $e6

; High data
; The framebuffer is split into two halves of half_height lines
; (half_height * stride = 3680 bytes each) that start at $8000 and $9000.
; Presumably this keeps each half from crossing a 4 KB boundary, which
; ANTIC's screen-memory counter cannot do.
framebuffer_top    = $8000
textbuffer         = $8f00 ; 40 bytes in the slack after the top half
framebuffer_bottom = $9000
display_list       = $9f00 ; in the slack after the bottom half
framebuffer_end    = $a000

height = 184
half_height = height >> 1
width = 160
half_width = width >> 1
stride = width >> 2       ; bytes per line: 4 pixels per byte at 2 bpp
width_ratio_3_13 = (5 << 11) ; 5/4
height_ratio_3_13 = (3 << 11) ; 3/4

; ANTIC hardware registers
DMACTL = $D400
DLISTL = $D402
DLISTH = $D403

; OS shadow registers
SDLSTL = $230
SDLSTH = $231

.data

; Status-bar strings. They are stored without a terminator; each *_len
; constant is computed from the end label, and draw_text takes an explicit
; length.
strings:
str_self:
    .byte "MANDEL-6502"
str_self_end:
str_speed:
    .byte "ms/px"
str_speed_end:
str_run:
    .byte " RUN"
str_run_end:
str_done:
    .byte "DONE"
str_done_end:

; byte lengths of the strings above
str_self_len = str_self_end - str_self
str_speed_len = str_speed_end - str_speed
str_run_len = str_run_end - str_run
str_done_len = str_done_end - str_done

char_map:
    ; Translate ATASCII codes 0..127, as stored in the status strings, into
    ; the font's internal character codes, which order the set differently:
    ;   ATASCII  0..31  -> internal 64..95
    ;   ATASCII 32..95  -> internal  0..63
    ;   ATASCII 96..127 -> internal 96..127 (unchanged)
    .repeat 128, atascii
        .if atascii < 32
            .byte atascii + 64
        .elseif atascii < 96
            .byte atascii - 32
        .else
            .byte atascii
        .endif
    .endrepeat

aspect:
    ; Pixel aspect-ratio correction factors, fixed 3.13.
    ; Pixels at 320w are 5:6 (narrow); at 160w they are 5:3 (wide).
    ;
    ; cy = (sy << (8 - zoom)) * (96 / 128 = 3 / 4)
    ; cx = (sx << (8 - zoom)) * ((3 / 4) * (5 / 3) = 5 / 4)
    ;
    ; so vertical range -92 .. 91.9 is -2.15625 .. 2.15624
    ; and horizontal range -80 .. 79.9 is -3.125 .. 3.124
    ;
    ; 184h is the equivalent of 220.8h at square pixels:
    ; 320 / 220.8 = 1.45 display aspect ratio
aspect_x:
    .word (5 << 13) / 4     ; 5/4 in 3.13 = $2800

aspect_y:
    .word (3 << 13) / 4     ; 3/4 in 3.13 = $1800


bit_masks:
    ; Masks for the four 2-bit pixel fields of a byte, lowest field first.
    ; NOTE(review): not referenced by the code shown here; confirm before removing.
    .byte %00000011, %00001100, %00110000, %11000000

display_list_start:
    ; Display list template; start copies it to display_list.

    ; 24 blank scan lines of overscan: three "8 blank lines" ($70) entries
    .res 3, $70

    ; 1 row of 40-column text (8 scan lines), with its memory address
    .byte $42
    .addr textbuffer

    ; 184 lines of graphics, ANTIC mode e (160px 2bpp, 1 scan line per line),
    ; in two halves. Each half opens with a $4e line carrying its memory
    ; address, followed by plain $0e lines.
    .byte $4e
    .addr framebuffer_top
    .res half_height - 1, $0e

    .byte $4e
    .addr framebuffer_bottom
    .res half_height - 1, $0e

    .byte $41 ; jump and wait for vertical blank
    .addr display_list
display_list_end:
display_list_len = display_list_end - display_list_start

color_map:
    ; Iteration count -> 2-bit pixel color.
    ; iter = 0 (no escape within the iteration limit) draws color 0.
    ; Escaped points cycle through colors 1, 2, 3, 1, 2, 3, ...
    .byte 0
    .repeat 255, n
        .byte (n .mod 3) + 1
    .endrepeat

.code

.export start

; Multi-byte addition: dest = arg1 + arg2, little-endian, `bytes` bytes wide.
; Each byte is read before it is written, so dest may be arg1 or arg2.
; clobbers A, flags
; 2 + 9 * byte cycles (zero-page operands)
.macro add bytes, dest, arg1, arg2
    clc ; 2 cyc
    .repeat bytes, byte ; 9 * byte cycles
        lda arg1 + byte
        adc arg2 + byte
        sta dest + byte
    .endrepeat
.endmacro

; 16-bit dest = arg1 + arg2
.macro add16 dest, arg1, arg2
    add 2, dest, arg1, arg2
.endmacro

; 32-bit dest = arg1 + arg2
; (previously expanded with the operands shuffled, computing arg2 + dest)
.macro add32 dest, arg1, arg2
    add 4, dest, arg1, arg2
.endmacro

; Multi-byte subtraction: dest = arg1 - arg2, low byte first.
; clobbers A, flags
; 2 + 9 * byte cycles
.macro sub bytes, dest, arg1, arg2
    sec                 ; 2 cyc - no borrow into the low byte
    .repeat bytes, n    ; 9 cyc per byte
        lda arg1 + n
        sbc arg2 + n
        sta dest + n
    .endrepeat
.endmacro

; 16-bit dest = arg1 - arg2
.macro sub16 dest, arg1, arg2
    sub 2, dest, arg1, arg2
.endmacro

; 32-bit dest = arg1 - arg2
.macro sub32 dest, arg1, arg2
    sub 4, dest, arg1, arg2
.endmacro

; In-place multi-byte shift left by one bit, low byte first.
; The bit shifted out of the top byte ends up in carry.
; clobbers flags
.macro shl bytes, arg
    asl arg                     ; low byte: a zero comes in
    .repeat bytes - 1, k
        rol arg + k + 1         ; each higher byte takes the carry from below
    .endrepeat
.endmacro

; 16-bit arg <<= 1
.macro shl16 arg
    shl 2, arg
.endmacro

; 24-bit arg <<= 1
.macro shl24 arg
    shl 3, arg
.endmacro

; 32-bit arg <<= 1
.macro shl32 arg
    shl 4, arg
.endmacro

; Multi-byte copy: dest = arg, low byte first.
; clobbers A, flags (N/Z)
; 6 * bytes cycles
.macro copy bytes, dest, arg
    .repeat bytes, n
        lda arg + n     ; 3 cyc
        sta dest + n    ; 3 cyc
    .endrepeat
.endmacro

; 16-bit dest = arg
.macro copy16 dest, arg
    copy 2, dest, arg
.endmacro

; 32-bit dest = arg
.macro copy32 dest, arg
    copy 4, dest, arg
.endmacro

; In-place two's-complement negation: arg = 0 - arg, low byte first.
; clobbers A, flags
; 2 + 8 * bytes cycles
.macro neg bytes, arg
    sec                 ; 2 cyc
    .repeat bytes, n    ; 8 cyc per byte
        lda #0          ; 2 cyc
        sbc arg + n     ; 3 cyc
        sta arg + n     ; 3 cyc
    .endrepeat
.endmacro

; 16-bit arg = -arg, 18 cycles
.macro neg16 arg
    neg 2, arg
.endmacro

; 32-bit arg = -arg, 34 cycles
.macro neg32 arg
    neg 4, arg
.endmacro

.macro extend_8_16 dest, src
    ; Sign-extend the 8-bit value at src into the 16-bit value at dest.
    ; clobbers A, X, flags
    ; 13-15 cycles
    .local is_positive
    ldx #0           ; 2 cyc - high byte for a non-negative src
    lda src          ; 3 cyc - N now reflects bit 7 of src
    sta dest         ; 3 cyc - low byte copies straight over; sta keeps N
    bpl is_positive  ; 2 cyc
    dex              ; 2 cyc - X = $ff for a negative src
is_positive:
    stx dest + 1     ; 3 cyc
.endmacro

; inner loop for imul16
; bitnum < 8: 25 or 41 cycles
; bitnum >= 8: 30 or 46 cycles
;
; One step of unsigned shift-and-add multiplication. It consumes bit
; `bitnum` of arg1; bits are processed LSB first, bitnum = 0..15.
;   - If the bit is set, arg2 is added into the top 16 bits
;     (result+2..result+3).
;   - The 32-bit accumulator is then shifted right by one, with the
;     adder's carry entering at the top.
; After all 16 steps, result..result+3 holds the 32-bit product. The caller
; must zero result+2..result+3 before the first step.
;
; result+0 is only rotated for bitnum >= 8. During the first 8 steps, the
; bytes below result+2 hold only stale data, which gets shifted out. The
; last 8 steps rotate product bits 0..7 out of result+1 into result+0.
;
; clobbers A, flags
.macro bitmul16 arg1, arg2, result, bitnum
    .local zero
    .local one
    .local next

    ; does 16-bit adds
    ; arg1 and arg2 are treated as unsigned
    ; negative signed inputs must be flipped first

    ; 7 cycles up to the branch

    ; check if arg1 has 0 or 1 bit in this place
    ; 5 cycles either way
    .if bitnum < 8
        lda arg1                 ; 3 cyc
        and #(1 << bitnum)       ; 2 cyc
    .else
        lda arg1 + 1             ; 3 cyc
        and #(1 << (bitnum - 8)) ; 2 cyc
    .endif
    bne one ; 2 cyc

zero: ; 18 cyc, 23 cyc
    ; bit clear: nothing to add, just shift (a 0 enters at the top)
    lsr result + 3 ; 5 cyc
    jmp next       ; 3 cyc

one: ; 32 cyc, 37 cyc
    ; 16-bit add on the top bits
    clc            ; 2 cyc
    lda result + 2 ; 3 cyc
    adc arg2       ; 3 cyc
    sta result + 2 ; 3 cyc
    lda result + 3 ; 3 cyc
    adc arg2 + 1   ; 3 cyc
    ror a          ; 2 cyc - get a jump on the shift
    sta result + 3 ; 3 cyc
next:
    ; on both paths, carry now holds the bit shifted out of result + 3
    ror result + 2 ; 5 cyc
    ror result + 1 ; 5 cyc
    .if bitnum >= 8
        ; we can save 5 cycles * 8 bits = 40 cycles total by skipping this byte
        ; when it's all uninitialized data
        ror result ; 5 cyc
    .endif


.endmacro

; 5 to 25 cycles
.macro check_sign arg
    ; If the 16-bit value at arg is negative, negate it in place and bump X.
    ; Running this over several operands leaves X = the number of operands
    ; that were negative.
    ; clobbers A, flags
    .local done
    lda arg + 1   ; 3 cyc - sign is bit 7 of the high byte
    bpl done      ; 2 cyc
    neg16 arg     ; 18 cyc
    inx           ; 2 cyc
done:
.endmacro

; Signed 16 x 16 multiply: dest (32-bit) = arg1 * arg2.
; The operands are copied into FR0/FR1 first, so arg1/arg2 are not modified.
; clobbers A, X, FR0-FR2, flags
; 518 - 828 cyc
.macro imul16 dest, arg1, arg2
    copy16 FR0, arg1     ; 12 cyc
    copy16 FR1, arg2     ; 12 cyc
    jsr imul16_func      ; 470-780 cyc
    copy32 dest, FR2     ; 24 cyc
.endmacro

; Signed 16 x 16 multiply, keeping only the top 16 bits of the 32-bit
; product, rounded with round16: dest (16-bit) = round(arg1 * arg2 >> 16).
; clobbers A, X, FR0-FR2, flags
.macro imul16_round dest, arg1, arg2
    copy16 FR0, arg1         ; 12 cyc
    copy16 FR1, arg2         ; 12 cyc
    jsr imul16_func          ; 470-780 cyc
    round16 FR2              ; rounds FR2+2..FR2+3 in place
    copy16 dest, FR2 + 2     ; 12 cyc
.endmacro

; Signed 16 x 16 -> 32-bit multiply.
;
; in:  FR0 = arg1, FR1 = arg2: 16-bit two's complement. Both are clobbered,
;      because negative inputs are negated in place.
; out: FR2..FR2+3 = 32-bit two's complement product
; clobbers A, X, flags
;
; Multiplies the magnitudes, then negates the product when exactly one
; input was negative. $8000 negates to itself, but the unsigned multiply
; reads it as 32768, so its magnitude still comes out right.
;
; min 470 cycles
; max 780 cycles
.proc imul16_func
    arg1 = FR0   ; 16-bit arg (clobbered)
    arg2 = FR1   ; 16-bit arg (clobbered)
    result = FR2 ; 32-bit result

    ldx #0          ; 2 cyc
    ; counts the number of sign bits in X
    check_sign arg1 ; 5 to 25 cyc
    check_sign arg2 ; 5 to 25 cyc
    ; X = number of negative inputs (0, 1 or 2); bitmul16 does not touch X
    
    ; zero out the 32-bit temp's top 16 bits
    lda #0          ; 2 cyc
    sta result + 2  ; 3 cyc
    sta result + 3  ; 3 cyc
    ; the bottom two bytes will get cleared by the shifts

    ; unrolled loop for maximum speed, at the cost
    ; of a larger routine
    ; 440 to 696 cycles
    .repeat 16, bitnum
        ; bitnum < 8: 25 or 41 cycles
        ; bitnum >= 8: 30 or 46 cycles
        bitmul16 arg1, arg2, result, bitnum
    .endrepeat

    ; In case of mixed input signs, return a negative result.
    cpx #1              ; 2 cyc
    bne positive_result ; 2 cyc
    neg32 result        ; 34 cyc
positive_result:

    rts ; 6 cyc
.endproc

.macro round16 arg
    ; Round the 32-bit two's-complement fixed-point value at arg..arg+3 to
    ; its top 16 bits (arg+2..arg+3), in place, rounding half up.
    ;
    ; In two's complement, the low (fractional) bits always add a
    ; non-negative amount to the top half. Rounding to nearest is therefore
    ; the same for both signs: add 1 to the top 16 bits when bit 7 of arg+1
    ; (the 1/2 place) is set. For example, -1.25 = $FFFE.C000 rounds to
    ; $FFFF = -1.
    ;
    ; The low 16 bits are left untouched; callers only use arg+2..arg+3.
    ; A top half of $7FFF with the half bit set wraps to $8000 (not handled).
    ;
    ; clobbers A, flags
    ;
    ; no round           -  6 cycles
    ; round, no carry    - 13 cycles
    ; round, carry       - 17 cycles
    .local done

    lda arg + 1  ; 3 cyc - bit 7 is the 1/2 place
    bpl done     ; 2/3 cyc - fraction below 1/2: truncate
    inc arg + 2  ; 5 cyc
    bne done     ; 2/3 cyc - no carry into the high byte
    inc arg + 3  ; 5 cyc
done:

.endmacro

; Shift the 32-bit value at arg left by `shift` bits, then round it to its
; top 16 bits (arg+2..arg+3) in place with round16.
; clobbers A, flags
.macro shift_round_16 arg, shift
    .repeat shift
        shl32 arg       ; arg <<= 1
    .endrepeat
    round16 arg         ; round top half using the next-lower bit
.endmacro

.proc mandelbrot
    ; Iterate z = z^2 + c for a single point c = (cx, cy), starting at z = 0.
    ;
    ; input:
    ; cx: position scaled to 4.12 fixed point - -8..+7.9
    ; cy: position scaled to 4.12
    ;
    ; output:
    ; iter: iteration count at escape, or 0 if the point had not escaped
    ;       after 255 iterations
    ;
    ; clobbers A, X, FR0-FR2, every zero-page byte from zx through iter, flags

    ; zx = 0
    ; zy = 0
    ; zx_2 = 0
    ; zy_2 = 0
    ; zx_zy = 0
    ; dist = 0
    ; iter = 0
    lda #00
    ldx #(iter - zx + 1)
initloop:
    sta zx - 1,x
    dex
    bne initloop

loop:
    ; iter++ & max-iters break (iter wraps to 0 after 255)
    inc iter
    bne keep_going
    rts
keep_going:

    ; 4.12: (-8 .. +7.9)
    ; The products' top 16 bits were rounded to 4.12 at the end of the
    ; previous pass (all zero on the first pass).
    ; zx = zx_2  - zy_2  + cx
    sub16 zx, zx_2 + 2, zy_2 + 2
    add16 zx, zx, cx

    ; zy = zx_zy + zx_zy + cy
    add16 zy, zx_zy + 2, zx_zy + 2
    add16 zy, zy, cy

    ; 8.24: (-128 .. +127.9)
    ; zx_2 = zx * zx
    imul16 zx_2, zx, zx

    ; zy_2 = zy * zy
    imul16 zy_2, zy, zy

    ; zx_zy = zx * zy
    imul16 zx_zy, zx, zy

    ; dist = zx_2 + zy_2
    add32 dist, zx_2, zy_2

    ; if dist >= 4 break, else continue iterating.
    ; dist + 3 is the integer byte of the 8.24 sum. Compare unsigned (bcc):
    ; a signed test (bmi) would treat integer parts of 132 and up as < 4.
    lda dist + 3
    cmp #4
    bcc still_in
    rts
still_in:

    ; rescale each product so its top 16 bits are rounded 4.12
    shift_round_16 zx_2, 4
    shift_round_16 zy_2, 4
    shift_round_16 zx_zy, 4

    ; TODO: if the point may be in the lake, look for looping output with a
    ; small buffer as an optimization vs running to max iters
    jmp loop

.endproc

.macro zoom_factor dest, src, zoom, aspect
    ; Map a screen coordinate to a complex-plane coordinate:
    ;   dest = round((src << (8 - zoom)) * aspect), via imul16_round
    ; zoom is expected to be 0..8. The shift loop counts X up from zoom and
    ; stops at 8, so larger values wrap X around before stopping.
    ; clobbers A, X, FR0-FR2, flags
    ;
    ; NOTE(review): imul16_round keeps the top 16 bits of the product with
    ; no pre-shift, so the result has (fraction bits of dest + fraction bits
    ; of aspect - 16) fraction bits. For example, two 3.13 operands give a
    ; 6.26 product whose top 16 bits are 6.10. Confirm this matches the
    ; format mandelbrot expects for cx/cy.
    .local shift_more
    .local shifted

    copy16 dest, src        ; dest = src
    ldx zoom
shift_more:
    cpx #8
    beq shifted
    shl16 dest              ; dest <<= 1, (8 - zoom) times in total
    inx
    jmp shift_more
shifted:

    ; scale by the pixel aspect factor (cy: 3/4, cx: 5/4)
    imul16_round dest, dest, aspect
.endmacro

.proc pset
    ; Plot one pixel in the 2-bpp framebuffer.
    ;
    ; input:
    ; sx:   signed screen x, -half_width .. half_width - 1
    ; sy:   signed screen y, -half_height .. half_height - 1
    ; iter: iteration count, mapped to a 2-bit color through color_map
    ;
    ; clobbers A, X, Y, temp, pixel_ptr, pixel_color, pixel_mask,
    ; pixel_shift, pixel_offset, flags

    ; iter -> color in the low 2 bits; mask clears only the low 2 bits
    ldx iter
    lda color_map,x
    sta pixel_color
    lda #(255 - 3)
    sta pixel_mask

    ; Choose the framebuffer half, base address into pixel_ptr.
    ; Only the low byte's sign is tested, which is enough for -128..127.
    lda sy
    bpl positive

negative:
    ; top half: negative rows are addressed back from the end of the half
    lda #.lobyte(framebuffer_top + stride * half_height)
    sta pixel_ptr
    lda #.hibyte(framebuffer_top + stride * half_height)
    sta pixel_ptr + 1
    jmp point

positive:
    ; bottom half: row 0 starts at framebuffer_bottom
    lda #.lobyte(framebuffer_bottom)
    sta pixel_ptr
    lda #.hibyte(framebuffer_bottom)
    sta pixel_ptr + 1

point:

    ; pixel_ptr += sy * stride, in 16-bit two's complement (sy may be < 0).
    ; The shifts hard-code stride = 40:
    ;    sy * 40
    ; =  sy * 8  +  sy * 32
    ; = (sy << 3) + (sy << 5)
    copy16 temp, sy
    shl16 temp
    shl16 temp
    shl16 temp
    add16 pixel_ptr, pixel_ptr, temp    ; + sy * 8
    shl16 temp
    shl16 temp
    add16 pixel_ptr, pixel_ptr, temp    ; + sy * 32

    ; pixel_ptr now points at the start of the line (stride bytes).
    ; temp = sx + half_width, the unsigned column 0 .. width - 1
    lda sx
    clc
    adc #half_width
    sta temp

    ; Four pixels per byte: column (temp & 3) = 0 lives in bits 7-6 and
    ; column 3 in bits 1-0. Move the color and the mask's hole left by
    ; 2 * (3 - pixel_shift) bits, where pixel_shift = temp & 3.
    ; pixel_color shifts in zeros; pixel_mask shifts in ones.
    and #3
    sta pixel_shift
    lda #3
    sec
    sbc pixel_shift
    tax                 ; X = number of 2-bit steps; Z set if none
shift_loop:
    beq shift_done
    asl pixel_color
    asl pixel_color
    sec
    rol pixel_mask
    sec
    rol pixel_mask
    dex
    jmp shift_loop
shift_done:

    ; pixel_offset = temp >> 2, the byte within the line
    lda temp
    lsr a
    lsr a
    sta pixel_offset
    tay

    ; read, mask, or, write
    lda (pixel_ptr),y
    and pixel_mask
    ora pixel_color
    sta (pixel_ptr),y

    rts
.endproc

.macro draw_text col, len, cstr
    ; Copy `len` bytes of the string at cstr into textbuffer, starting at
    ; column `col`, translating each byte through char_map.
    ; len must be an assemble-time constant (it is used as an immediate);
    ; len = 0 writes nothing.
    ; clobbers A, X, Y, flags
    .local loop
    .local done
    ldx #0
loop:
    cpx #len
    beq done
    ldy cstr,x              ; Y = ATASCII code
    lda char_map,y          ; A = font code
    sta textbuffer + col,x
    inx
    jmp loop
done:
.endmacro

.proc start
    ; Program entry point. Clears screen memory, installs the display list
    ; and status bar, renders every pixel once, then spins forever.

    ; OS shadow of DMACTL. The OS vertical-blank routine copies it into
    ; DMACTL, so DMA changes must also be made here; otherwise they only
    ; last until the next vertical blank.
    SDMCTL = $22F

    ; ox = 0; oy = 0; zoom = 0
    lda #0
    sta ox
    sta ox + 1
    sta oy
    sta oy + 1
    sta zoom

    ; Disable display DMA (hardware register and its OS shadow)
    sta DMACTL
    sta SDMCTL

    ; zero the range from framebuffer_top to framebuffer_end
    ; (whole 256-byte pages, so framebuffer_top is assumed page-aligned)
    lda #.lobyte(framebuffer_top)
    sta temp
    lda #.hibyte(framebuffer_top)
    sta temp + 1

zero_page_loop:
    lda #0
    ldy #0
zero_byte_loop:
    sta (temp),y
    iny
    bne zero_byte_loop

    inc temp + 1
    lda temp + 1
    cmp #.hibyte(framebuffer_end)
    bne zero_page_loop

    ; Copy the display list into properly aligned memory.
    ; It can't cross a 1024-byte boundary, and this X-indexed loop
    ; requires display_list_len < 256.
    ldx #0
copy_byte_loop:
    lda display_list_start,x
    sta display_list,x
    inx
    cpx #display_list_len
    bne copy_byte_loop

    ; Set up the display list
    lda #.lobyte(display_list)
    sta DLISTL ; actual register
    sta SDLSTL ; shadow register the OS will copy in
    lda #.hibyte(display_list)
    sta DLISTH ; actual register
    sta SDLSTH ; shadow register the OS will copy in

    ; Status bar
    draw_text 0, str_self_len, str_self
    draw_text 40 - str_run_len, str_run_len, str_run

    ; Re-enable display DMA: $22 = display-list DMA on, normal-width playfield
    lda #$22
    sta DMACTL
    sta SDMCTL

main_loop:
    ; sy = -92 .. 91
    lda #(256-half_height)
    sta sy
    lda #(256-1)
    sta sy + 1

loop_sy:
    ; sx = -80 .. 79
    lda #(256-half_width)
    sta sx
    lda #(256-1)
    sta sx + 1

loop_sx:
    zoom_factor cx, sx, zoom, aspect_x
    zoom_factor cy, sy, zoom, aspect_y
    jsr mandelbrot
    jsr pset

    ; sx++ (16-bit)
    clc
    lda sx
    adc #1
    sta sx
    lda sx + 1
    adc #0
    sta sx + 1

    ; stop once sx reaches half_width
    lda sx
    cmp #half_width
    beq loop_sx_done
    jmp loop_sx

loop_sx_done:

    ; sy++ (16-bit)
    clc
    lda sy
    adc #1
    sta sy
    lda sy + 1
    adc #0
    sta sy + 1

    ; stop once sy reaches half_height
    lda sy
    cmp #half_height
    beq loop_sy_done
    jmp loop_sy

loop_sy_done:

    draw_text 40 - str_done_len, str_done_len, str_done

loop:
    ; finished
    jmp loop
.endproc