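Super Mario, now thirty years old, was long the best-selling video game of all time, and its hero Mario is known to almost everyone. For people born in the 1970s and 1980s especially, he is an indelible part of childhood. Now several of the world's best-known AI research groups are turning their attention back to classic console games like it.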
DeepMind, valued at $400 million
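DeepMind, which Google bought for $400 million, has demonstrated a remarkable AI. One and the same system learns to play more than forty Atari games through repeated trial and error and self-learning. In many of those games it eventually reaches scores far above those of human players.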
[Figure: DeepMind neural network]
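DeepMind is not the only AI that can play games, though. Early this year, the Cognitive Modeling group at the University of Tübingen in Germany posted a video in which Mario takes instructions in human language, perceives his surroundings in the game, and decides on his own which actions to take. This Mario understands English sentences and commands. Using a large logic and grammar tree, he can answer questions about what exactly a human has told him, and he chooses correct in-game actions based on what he has learned. Compared with DeepMind, however, the Tübingen AI still needs human intervention, and its ability to learn on its own by trial and error is much weaker.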
A publicly available AI in just over a thousand lines of code
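More remarkable still, SethBling later built an AI called MarI/O and published its complete source code, which is only a little over a thousand lines long!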
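Unlike the Tübingen program, MarI/O does not even know what the end of the level looks like when it starts playing. It is given only a few simple parameters. It scores each attempt with a "fitness" value. The farther right Mario gets, the higher the fitness. Every frame spent costs a little (half a point per frame in the script), and getting past the end of the level adds a 1000-point bonus. The networks with higher fitness are the ones selected to reproduce. So once moving right turns out to raise the score, evolution keeps pushing Mario to the right.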
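Below is a minimal sketch of that scoring rule, written as a standalone Lua function so it can be run outside the emulator. The formula is copied from the main loop of the script: distance reached, minus half the elapsed frames, plus 1000 past the level end, with 0 remapped to -1. The function name `computeFitness` and its parameters are hypothetical, since the script computes this inline.

```lua
-- Hypothetical helper mirroring MarI/O's fitness formula (a sketch, not part of the script).
-- rightmost: farthest x position Mario reached during the run
-- frames:    number of frames the run lasted
-- levelEnd:  x position past which the level counts as cleared
--            (the script uses 4816 for Super Mario World, 3186 for Super Mario Bros.)
function computeFitness(rightmost, frames, levelEnd)
  -- Reward distance, penalize time: every 2 frames cost 1 point.
  local fitness = rightmost - frames / 2
  -- Large bonus for getting past the end of the level.
  if rightmost > levelEnd then
    fitness = fitness + 1000
  end
  -- The script uses 0 to mean "not yet measured", so a real score of 0 becomes -1.
  if fitness == 0 then
    fitness = -1
  end
  return fitness
end

print(computeFitness(1200, 600, 3186))
print(computeFitness(3200, 2000, 3186))
```

For example, a run that reaches x = 1200 after 600 frames scores 1200 - 600/2 = 900.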
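During evolution, MarI/O does not predict the outcome of its actions. Rather than doing what it "should" do, it simply tries variations, so every generation produces new ideas. Ideas that succeed are kept, and those that fail are discarded. After 34 generations of this, Mario cleared the whole level! If the AI were run again from scratch, it would almost certainly find a different route, though not a more successful one.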
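The "try something, keep what works" idea can be shown in isolation with a deliberately simplified random-mutation hill climber. It mutates a candidate at random and keeps the mutant only if it scores at least as well. This is not NEAT: there is no population, speciation, or crossover. The toy objective (matching a hypothetical `TARGET` vector) is an assumption chosen for illustration. MarI/O itself evaluates a population of 300 networks per generation.

```lua
-- Toy "mutate, keep if not worse" loop (random-mutation hill climbing), NOT MarI/O's NEAT.
-- Hypothetical objective: a weight vector scores higher the closer it is to TARGET.
TARGET = {0.5, -1.0, 0.25, 2.0, -0.75}

-- Negative squared error: higher is better, 0 is a perfect match.
function score(weights)
  local err = 0
  for i = 1, #TARGET do
    err = err + (weights[i] - TARGET[i]) ^ 2
  end
  return -err
end

-- Perturb every weight uniformly in [-step, step), like the "perturb" branch of pointMutate.
function mutateCopy(weights, step)
  local child = {}
  for i = 1, #weights do
    child[i] = weights[i] + math.random() * step * 2 - step
  end
  return child
end

function evolve(tries, step)
  local best = {0, 0, 0, 0, 0}
  local bestScore = score(best)
  for _ = 1, tries do
    local candidate = mutateCopy(best, step)
    local s = score(candidate)
    if s >= bestScore then
      best, bestScore = candidate, s   -- the idea worked: remember it
    end                                -- otherwise the candidate is simply dropped
  end
  return best, bestScore
end

math.randomseed(42)
initialScore = score({0, 0, 0, 0, 0})
best, bestScore = evolve(2000, 0.1)
print(string.format("start %.3f, after 2000 tries %.3f", initialScore, bestScore))
```

Because a candidate is only accepted when it is no worse, the best score can never decrease from one try to the next. That is the same property MarI/O gets, more elaborately, by letting only the fitter genomes breed.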
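This learning method is called NeuroEvolution of Augmenting Topologies (NEAT). It is not a new technique: Kenneth Stanley and Risto Miikkulainen introduced it in 2002. But the author applies it very effectively here. In just over a thousand lines of Lua, MarI/O achieves, on a single level, something reminiscent of what the $400 million DeepMind did, which is impressive. Of course, this is only a nice demonstration, and the approach still has a long way to go before it can compete with more powerful machine-learning algorithms.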
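NEAT groups genomes into species so that new structure has time to mature before it competes with the whole population. The script's `sameSpecies` decides membership from two quantities. The first is the fraction of genes whose innovation numbers appear in only one of the two genomes. The second is the mean absolute weight difference of the genes they share.

The sketch below restates that test on bare lists of genes. The `{innovation, weight}` gene representation and the names `compatibility` and `sameSpeciesSketch` are assumptions. The coefficients 2.0, 0.4 and 1.0 are `DeltaDisjoint`, `DeltaWeights` and `DeltaThreshold` from the script. Unlike the script, it guards against dividing by zero when there are no genes or no shared genes.

```lua
-- Sketch of the speciation test performed by MarI/O's sameSpecies(), on simplified genes.
-- Assumed gene representation: {innovation = <integer>, weight = <number>}.
DeltaDisjoint, DeltaWeights, DeltaThreshold = 2.0, 0.4, 1.0   -- values from the script

function compatibility(genes1, genes2)
  local byInnov1, byInnov2 = {}, {}
  for _, g in ipairs(genes1) do byInnov1[g.innovation] = g end
  for _, g in ipairs(genes2) do byInnov2[g.innovation] = g end

  -- Genes whose innovation number appears in only one genome
  -- (like the script, excess genes are counted together with disjoint ones).
  local disjoint = 0
  for _, g in ipairs(genes1) do
    if byInnov2[g.innovation] == nil then disjoint = disjoint + 1 end
  end
  for _, g in ipairs(genes2) do
    if byInnov1[g.innovation] == nil then disjoint = disjoint + 1 end
  end

  -- Mean absolute weight difference over the genes both genomes share.
  local weightSum, shared = 0, 0
  for _, g in ipairs(genes1) do
    local other = byInnov2[g.innovation]
    if other ~= nil then
      weightSum = weightSum + math.abs(g.weight - other.weight)
      shared = shared + 1
    end
  end

  local n = math.max(#genes1, #genes2, 1)   -- guard against 0/0 for empty genomes
  local meanWeightDiff = 0
  if shared > 0 then meanWeightDiff = weightSum / shared end
  return DeltaDisjoint * disjoint / n + DeltaWeights * meanWeightDiff
end

function sameSpeciesSketch(genes1, genes2)
  return compatibility(genes1, genes2) < DeltaThreshold
end

-- Usage: a small weight tweak stays in the species; new structure splits it off.
parent  = {{innovation = 1, weight = 0.5}, {innovation = 2, weight = -0.3}, {innovation = 3, weight = 1.0}}
tweaked = {{innovation = 1, weight = 0.6}, {innovation = 2, weight = -0.3}, {innovation = 3, weight = 1.0}}
rewired = {{innovation = 1, weight = 0.7}, {innovation = 2, weight = -0.3}, {innovation = 4, weight = 0.2}}
print(sameSpeciesSketch(parent, tweaked), sameSpeciesSketch(parent, rewired))
```

Here `tweaked` only shifts one weight by 0.1, giving a distance of about 0.013, so it stays in the parent's species. In `rewired`, one gene is replaced by a new connection. That gives 2 disjoint genes out of 3, a distance of about 1.37, which is above the threshold of 1.0.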
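If you want to study AI, or apply this algorithm and its ideas to your own work or studies, you can download the source and try it locally in an emulator. The script is written for the BizHawk emulator. It expects a Super Mario World (USA) or Super Mario Bros. ROM, plus a starting savestate named DP1.state or SMB1-1.state respectively: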
-- MarI/O by SethBling
-- Feel free to use this code, but please do not redistribute it.
-- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM.

if gameinfo.getromname() == "Super Mario World (USA)" then
  Filename = "DP1.state"
  ButtonNames = {"A", "B", "X", "Y", "Up", "Down", "Left", "Right"}
elseif gameinfo.getromname() == "Super Mario Bros." then
  Filename = "SMB1-1.state"
  ButtonNames = {"A", "B", "Up", "Down", "Left", "Right"}
end

-- Network inputs: a (2*BoxRadius+1) x (2*BoxRadius+1) tile grid around Mario, plus one bias input.
BoxRadius = 6
InputSize = (BoxRadius*2+1)*(BoxRadius*2+1)

Inputs = InputSize+1
Outputs = #ButtonNames

Population = 300
DeltaDisjoint = 2.0
DeltaWeights = 0.4
DeltaThreshold = 1.0

StaleSpecies = 15

MutateConnectionsChance = 0.25
PerturbChance = 0.90
CrossoverChance = 0.75
LinkMutationChance = 2.0
NodeMutationChance = 0.50
BiasMutationChance = 0.40
StepSize = 0.1
DisableMutationChance = 0.4
EnableMutationChance = 0.2

TimeoutConstant = 20

-- Output neurons are stored at indices MaxNodes+1 .. MaxNodes+Outputs,
-- so hidden neurons (numbered upward from Inputs+1) never collide with them.
MaxNodes = 1000000

function getPositions()
  if gameinfo.getromname() == "Super Mario World (USA)" then
    marioX = memory.read_s16_le(0x94)
    marioY = memory.read_s16_le(0x96)

    local layer1x = memory.read_s16_le(0x1A)
    local layer1y = memory.read_s16_le(0x1C)

    screenX = marioX-layer1x
    screenY = marioY-layer1y
  elseif gameinfo.getromname() == "Super Mario Bros." then
    marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86)
    marioY = memory.readbyte(0x03B8)+16

    screenX = memory.readbyte(0x03AD)
    screenY = memory.readbyte(0x03B8)
  end
end

function getTile(dx, dy)
  if gameinfo.getromname() == "Super Mario World (USA)" then
    x = math.floor((marioX+dx+8)/16)
    y = math.floor((marioY+dy)/16)

    return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10)
  elseif gameinfo.getromname() == "Super Mario Bros." then
    local x = marioX + dx + 8
    local y = marioY + dy - 16
    local page = math.floor(x/256)%2

    local subx = math.floor((x%256)/16)
    local suby = math.floor((y - 32)/16)
    local addr = 0x500 + page*13*16+suby*16+subx

    if suby >= 13 or suby < 0 then
      return 0
    end

    if memory.readbyte(addr) ~= 0 then
      return 1
    else
      return 0
    end
  end
end

function getSprites()
  if gameinfo.getromname() == "Super Mario World (USA)" then
    local sprites = {}
    for slot=0,11 do
      local status = memory.readbyte(0x14C8+slot)
      if status ~= 0 then
        spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256
        spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256
        sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey}
      end
    end

    return sprites
  elseif gameinfo.getromname() == "Super Mario Bros." then
    local sprites = {}
    for slot=0,4 do
      local enemy = memory.readbyte(0xF+slot)
      if enemy ~= 0 then
        local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot)
        local ey = memory.readbyte(0xCF + slot)+24
        sprites[#sprites+1] = {["x"]=ex,["y"]=ey}
      end
    end

    return sprites
  end
end

function getExtendedSprites()
  if gameinfo.getromname() == "Super Mario World (USA)" then
    local extended = {}
    for slot=0,11 do
      local number = memory.readbyte(0x170B+slot)
      if number ~= 0 then
        spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256
        spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256
        extended[#extended+1] = {["x"]=spritex, ["y"]=spritey}
      end
    end

    return extended
  elseif gameinfo.getromname() == "Super Mario Bros." then
    return {}
  end
end

-- Encodes the area around Mario as a grid: 1 = solid tile, -1 = sprite/enemy, 0 = empty.
function getInputs()
  getPositions()

  sprites = getSprites()
  extended = getExtendedSprites()

  local inputs = {}

  for dy=-BoxRadius*16,BoxRadius*16,16 do
    for dx=-BoxRadius*16,BoxRadius*16,16 do
      inputs[#inputs+1] = 0

      tile = getTile(dx, dy)
      if tile == 1 and marioY+dy < 0x1B0 then
        inputs[#inputs] = 1
      end

      for i = 1,#sprites do
        distx = math.abs(sprites[i]["x"] - (marioX+dx))
        disty = math.abs(sprites[i]["y"] - (marioY+dy))
        if distx <= 8 and disty <= 8 then
          inputs[#inputs] = -1
        end
      end

      for i = 1,#extended do
        distx = math.abs(extended[i]["x"] - (marioX+dx))
        disty = math.abs(extended[i]["y"] - (marioY+dy))
        if distx < 8 and disty < 8 then
          inputs[#inputs] = -1
        end
      end
    end
  end

  --mariovx = memory.read_s8(0x7B)
  --mariovy = memory.read_s8(0x7D)

  return inputs
end

-- Steepened sigmoid rescaled to the range (-1, 1).
function sigmoid(x)
  return 2/(1+math.exp(-4.9*x))-1
end

function newInnovation()
  pool.innovation = pool.innovation + 1
  return pool.innovation
end

function newPool()
  local pool = {}
  pool.species = {}
  pool.generation = 0
  pool.innovation = Outputs
  pool.currentSpecies = 1
  pool.currentGenome = 1
  pool.currentFrame = 0
  pool.maxFitness = 0

  return pool
end

function newSpecies()
  local species = {}
  species.topFitness = 0
  species.staleness = 0
  species.genomes = {}
  species.averageFitness = 0

  return species
end

function newGenome()
  local genome = {}
  genome.genes = {}
  genome.fitness = 0
  genome.adjustedFitness = 0
  genome.network = {}
  genome.maxneuron = 0
  genome.globalRank = 0
  genome.mutationRates = {}
  genome.mutationRates["connections"] = MutateConnectionsChance
  genome.mutationRates["link"] = LinkMutationChance
  genome.mutationRates["bias"] = BiasMutationChance
  genome.mutationRates["node"] = NodeMutationChance
  genome.mutationRates["enable"] = EnableMutationChance
  genome.mutationRates["disable"] = DisableMutationChance
  genome.mutationRates["step"] = StepSize

  return genome
end

function copyGenome(genome)
  local genome2 = newGenome()
  for g=1,#genome.genes do
    table.insert(genome2.genes, copyGene(genome.genes[g]))
  end
  genome2.maxneuron = genome.maxneuron
  genome2.mutationRates["connections"] = genome.mutationRates["connections"]
  genome2.mutationRates["link"] = genome.mutationRates["link"]
  genome2.mutationRates["bias"] = genome.mutationRates["bias"]
  genome2.mutationRates["node"] = genome.mutationRates["node"]
  genome2.mutationRates["enable"] = genome.mutationRates["enable"]
  genome2.mutationRates["disable"] = genome.mutationRates["disable"]

  return genome2
end

function basicGenome()
  local genome = newGenome()
  local innovation = 1

  genome.maxneuron = Inputs
  mutate(genome)

  return genome
end

function newGene()
  local gene = {}
  gene.into = 0
  gene.out = 0
  gene.weight = 0.0
  gene.enabled = true
  gene.innovation = 0

  return gene
end

function copyGene(gene)
  local gene2 = newGene()
  gene2.into = gene.into
  gene2.out = gene.out
  gene2.weight = gene.weight
  gene2.enabled = gene.enabled
  gene2.innovation = gene.innovation

  return gene2
end

function newNeuron()
  local neuron = {}
  neuron.incoming = {}
  neuron.value = 0.0

  return neuron
end

-- Builds the runnable network (neurons with their incoming links) from a genome's enabled genes.
function generateNetwork(genome)
  local network = {}
  network.neurons = {}

  for i=1,Inputs do
    network.neurons[i] = newNeuron()
  end

  for o=1,Outputs do
    network.neurons[MaxNodes+o] = newNeuron()
  end

  table.sort(genome.genes, function (a,b)
    return (a.out < b.out)
  end)
  for i=1,#genome.genes do
    local gene = genome.genes[i]
    if gene.enabled then
      if network.neurons[gene.out] == nil then
        network.neurons[gene.out] = newNeuron()
      end
      local neuron = network.neurons[gene.out]
      table.insert(neuron.incoming, gene)
      if network.neurons[gene.into] == nil then
        network.neurons[gene.into] = newNeuron()
      end
    end
  end

  genome.network = network
end

function evaluateNetwork(network, inputs)
  table.insert(inputs, 1)   -- bias input
  if #inputs ~= Inputs then
    console.writeline("Incorrect number of neural network inputs.")
    return {}
  end

  for i=1,Inputs do
    network.neurons[i].value = inputs[i]
  end

  for _,neuron in pairs(network.neurons) do
    local sum = 0
    for j = 1,#neuron.incoming do
      local incoming = neuron.incoming[j]
      local other = network.neurons[incoming.into]
      sum = sum + incoming.weight * other.value
    end

    if #neuron.incoming > 0 then
      neuron.value = sigmoid(sum)
    end
  end

  local outputs = {}
  for o=1,Outputs do
    local button = "P1 " .. ButtonNames[o]
    if network.neurons[MaxNodes+o].value > 0 then
      outputs[button] = true
    else
      outputs[button] = false
    end
  end

  return outputs
end

function crossover(g1, g2)
  -- Make sure g1 is the higher fitness genome
  if g2.fitness > g1.fitness then
    tempg = g1
    g1 = g2
    g2 = tempg
  end

  local child = newGenome()

  local innovations2 = {}
  for i=1,#g2.genes do
    local gene = g2.genes[i]
    innovations2[gene.innovation] = gene
  end

  for i=1,#g1.genes do
    local gene1 = g1.genes[i]
    local gene2 = innovations2[gene1.innovation]
    if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then
      table.insert(child.genes, copyGene(gene2))
    else
      table.insert(child.genes, copyGene(gene1))
    end
  end

  child.maxneuron = math.max(g1.maxneuron,g2.maxneuron)

  for mutation,rate in pairs(g1.mutationRates) do
    child.mutationRates[mutation] = rate
  end

  return child
end

function randomNeuron(genes, nonInput)
  local neurons = {}
  if not nonInput then
    for i=1,Inputs do
      neurons[i] = true
    end
  end
  for o=1,Outputs do
    neurons[MaxNodes+o] = true
  end
  for i=1,#genes do
    if (not nonInput) or genes[i].into > Inputs then
      neurons[genes[i].into] = true
    end
    if (not nonInput) or genes[i].out > Inputs then
      neurons[genes[i].out] = true
    end
  end

  local count = 0
  for _,_ in pairs(neurons) do
    count = count + 1
  end
  local n = math.random(1, count)

  for k,v in pairs(neurons) do
    n = n-1
    if n == 0 then
      return k
    end
  end

  return 0
end

function containsLink(genes, link)
  for i=1,#genes do
    local gene = genes[i]
    if gene.into == link.into and gene.out == link.out then
      return true
    end
  end
end

function pointMutate(genome)
  local step = genome.mutationRates["step"]

  for i=1,#genome.genes do
    local gene = genome.genes[i]
    if math.random() < PerturbChance then
      gene.weight = gene.weight + math.random() * step*2 - step
    else
      gene.weight = math.random()*4-2
    end
  end
end

function linkMutate(genome, forceBias)
  local neuron1 = randomNeuron(genome.genes, false)
  local neuron2 = randomNeuron(genome.genes, true)

  local newLink = newGene()
  if neuron1 <= Inputs and neuron2 <= Inputs then
    --Both input nodes
    return
  end
  if neuron2 <= Inputs then
    -- Swap output and input
    local temp = neuron1
    neuron1 = neuron2
    neuron2 = temp
  end

  newLink.into = neuron1
  newLink.out = neuron2
  if forceBias then
    newLink.into = Inputs
  end

  if containsLink(genome.genes, newLink) then
    return
  end
  newLink.innovation = newInnovation()
  newLink.weight = math.random()*4-2

  table.insert(genome.genes, newLink)
end

-- Splits an enabled link in two by inserting a new hidden neuron.
function nodeMutate(genome)
  if #genome.genes == 0 then
    return
  end

  genome.maxneuron = genome.maxneuron + 1

  local gene = genome.genes[math.random(1,#genome.genes)]
  if not gene.enabled then
    return
  end
  gene.enabled = false

  local gene1 = copyGene(gene)
  gene1.out = genome.maxneuron
  gene1.weight = 1.0
  gene1.innovation = newInnovation()
  gene1.enabled = true
  table.insert(genome.genes, gene1)

  local gene2 = copyGene(gene)
  gene2.into = genome.maxneuron
  gene2.innovation = newInnovation()
  gene2.enabled = true
  table.insert(genome.genes, gene2)
end

function enableDisableMutate(genome, enable)
  local candidates = {}
  for _,gene in pairs(genome.genes) do
    if gene.enabled == not enable then
      table.insert(candidates, gene)
    end
  end

  if #candidates == 0 then
    return
  end

  local gene = candidates[math.random(1,#candidates)]
  gene.enabled = not gene.enabled
end

-- Each mutation rate first drifts down or up by about 5%. A rate p then gives
-- floor(p) guaranteed attempts plus one more with probability p - floor(p).
function mutate(genome)
  for mutation,rate in pairs(genome.mutationRates) do
    if math.random(1,2) == 1 then
      genome.mutationRates[mutation] = 0.95*rate
    else
      genome.mutationRates[mutation] = 1.05263*rate
    end
  end

  if math.random() < genome.mutationRates["connections"] then
    pointMutate(genome)
  end

  local p = genome.mutationRates["link"]
  while p > 0 do
    if math.random() < p then
      linkMutate(genome, false)
    end
    p = p - 1
  end

  p = genome.mutationRates["bias"]
  while p > 0 do
    if math.random() < p then
      linkMutate(genome, true)
    end
    p = p - 1
  end

  p = genome.mutationRates["node"]
  while p > 0 do
    if math.random() < p then
      nodeMutate(genome)
    end
    p = p - 1
  end

  p = genome.mutationRates["enable"]
  while p > 0 do
    if math.random() < p then
      enableDisableMutate(genome, true)
    end
    p = p - 1
  end

  p = genome.mutationRates["disable"]
  while p > 0 do
    if math.random() < p then
      enableDisableMutate(genome, false)
    end
    p = p - 1
  end
end

function disjoint(genes1, genes2)
  local i1 = {}
  for i = 1,#genes1 do
    local gene = genes1[i]
    i1[gene.innovation] = true
  end

  local i2 = {}
  for i = 1,#genes2 do
    local gene = genes2[i]
    i2[gene.innovation] = true
  end

  local disjointGenes = 0
  for i = 1,#genes1 do
    local gene = genes1[i]
    if not i2[gene.innovation] then
      disjointGenes = disjointGenes+1
    end
  end

  for i = 1,#genes2 do
    local gene = genes2[i]
    if not i1[gene.innovation] then
      disjointGenes = disjointGenes+1
    end
  end

  local n = math.max(#genes1, #genes2)

  return disjointGenes / n
end

function weights(genes1, genes2)
  local i2 = {}
  for i = 1,#genes2 do
    local gene = genes2[i]
    i2[gene.innovation] = gene
  end

  local sum = 0
  local coincident = 0
  for i = 1,#genes1 do
    local gene = genes1[i]
    if i2[gene.innovation] ~= nil then
      local gene2 = i2[gene.innovation]
      sum = sum + math.abs(gene.weight - gene2.weight)
      coincident = coincident + 1
    end
  end

  return sum / coincident
end

-- NEAT compatibility test: weighted disjoint-gene fraction plus mean weight difference of shared genes.
function sameSpecies(genome1, genome2)
  local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes)
  local dw = DeltaWeights*weights(genome1.genes, genome2.genes)
  return dd + dw < DeltaThreshold
end

function rankGlobally()
  local global = {}
  for s = 1,#pool.species do
    local species = pool.species[s]
    for g = 1,#species.genomes do
      table.insert(global, species.genomes[g])
    end
  end
  table.sort(global, function (a,b)
    return (a.fitness < b.fitness)
  end)

  for g=1,#global do
    global[g].globalRank = g
  end
end

function calculateAverageFitness(species)
  local total = 0

  for g=1,#species.genomes do
    local genome = species.genomes[g]
    total = total + genome.globalRank
  end

  species.averageFitness = total / #species.genomes
end

function totalAverageFitness()
  local total = 0
  for s = 1,#pool.species do
    local species = pool.species[s]
    total = total + species.averageFitness
  end

  return total
end

function cullSpecies(cutToOne)
  for s = 1,#pool.species do
    local species = pool.species[s]

    table.sort(species.genomes, function (a,b)
      return (a.fitness > b.fitness)
    end)

    local remaining = math.ceil(#species.genomes/2)
    if cutToOne then
      remaining = 1
    end
    while #species.genomes > remaining do
      table.remove(species.genomes)
    end
  end
end

function breedChild(species)
  local child = {}
  if math.random() < CrossoverChance then
    g1 = species.genomes[math.random(1, #species.genomes)]
    g2 = species.genomes[math.random(1, #species.genomes)]
    child = crossover(g1, g2)
  else
    g = species.genomes[math.random(1, #species.genomes)]
    child = copyGenome(g)
  end

  mutate(child)

  return child
end

function removeStaleSpecies()
  local survived = {}

  for s = 1,#pool.species do
    local species = pool.species[s]

    table.sort(species.genomes, function (a,b)
      return (a.fitness > b.fitness)
    end)

    if species.genomes[1].fitness > species.topFitness then
      species.topFitness = species.genomes[1].fitness
      species.staleness = 0
    else
      species.staleness = species.staleness + 1
    end
    if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then
      table.insert(survived, species)
    end
  end

  pool.species = survived
end

function removeWeakSpecies()
  local survived = {}

  local sum = totalAverageFitness()
  for s = 1,#pool.species do
    local species = pool.species[s]
    breed = math.floor(species.averageFitness / sum * Population)
    if breed >= 1 then
      table.insert(survived, species)
    end
  end

  pool.species = survived
end

function addToSpecies(child)
  local foundSpecies = false
  for s=1,#pool.species do
    local species = pool.species[s]
    if not foundSpecies and sameSpecies(child, species.genomes[1]) then
      table.insert(species.genomes, child)
      foundSpecies = true
    end
  end

  if not foundSpecies then
    local childSpecies = newSpecies()
    table.insert(childSpecies.genomes, child)
    table.insert(pool.species, childSpecies)
  end
end

function newGeneration()
  cullSpecies(false) -- Cull the bottom half of each species
  rankGlobally()
  removeStaleSpecies()
  rankGlobally()
  for s = 1,#pool.species do
    local species = pool.species[s]
    calculateAverageFitness(species)
  end
  removeWeakSpecies()
  local sum = totalAverageFitness()
  local children = {}
  for s = 1,#pool.species do
    local species = pool.species[s]
    breed = math.floor(species.averageFitness / sum * Population) - 1
    for i=1,breed do
      table.insert(children, breedChild(species))
    end
  end
  cullSpecies(true) -- Cull all but the top member of each species
  while #children + #pool.species < Population do
    local species = pool.species[math.random(1, #pool.species)]
    table.insert(children, breedChild(species))
  end
  for c=1,#children do
    local child = children[c]
    addToSpecies(child)
  end

  pool.generation = pool.generation + 1

  writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
end

function initializePool()
  pool = newPool()

  for i=1,Population do
    basic = basicGenome()
    addToSpecies(basic)
  end

  initializeRun()
end

function clearJoypad()
  controller = {}
  for b = 1,#ButtonNames do
    controller["P1 " .. ButtonNames[b]] = false
  end
  joypad.set(controller)
end

function initializeRun()
  savestate.load(Filename)
  rightmost = 0
  pool.currentFrame = 0
  timeout = TimeoutConstant
  clearJoypad()

  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]
  generateNetwork(genome)
  evaluateCurrent()
end

function evaluateCurrent()
  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]

  inputs = getInputs()
  controller = evaluateNetwork(genome.network, inputs)

  -- Opposite directions cancel out.
  if controller["P1 Left"] and controller["P1 Right"] then
    controller["P1 Left"] = false
    controller["P1 Right"] = false
  end
  if controller["P1 Up"] and controller["P1 Down"] then
    controller["P1 Up"] = false
    controller["P1 Down"] = false
  end

  joypad.set(controller)
end

if pool == nil then
  initializePool()
end

function nextGenome()
  pool.currentGenome = pool.currentGenome + 1
  if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then
    pool.currentGenome = 1
    pool.currentSpecies = pool.currentSpecies+1
    if pool.currentSpecies > #pool.species then
      newGeneration()
      pool.currentSpecies = 1
    end
  end
end

function fitnessAlreadyMeasured()
  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]

  return genome.fitness ~= 0
end

function displayGenome(genome)
  local network = genome.network
  local cells = {}
  local i = 1
  local cell = {}
  for dy=-BoxRadius,BoxRadius do
    for dx=-BoxRadius,BoxRadius do
      cell = {}
      cell.x = 50+5*dx
      cell.y = 70+5*dy
      cell.value = network.neurons[i].value
      cells[i] = cell
      i = i + 1
    end
  end
  local biasCell = {}
  biasCell.x = 80
  biasCell.y = 110
  biasCell.value = network.neurons[Inputs].value
  cells[Inputs] = biasCell

  for o = 1,Outputs do
    cell = {}
    cell.x = 220
    cell.y = 30 + 8 * o
    cell.value = network.neurons[MaxNodes + o].value
    cells[MaxNodes+o] = cell
    local color
    if cell.value > 0 then
      color = 0xFF0000FF
    else
      color = 0xFF000000
    end
    gui.drawText(223, 24+8*o, ButtonNames[o], color, 9)
  end

  for n,neuron in pairs(network.neurons) do
    cell = {}
    if n > Inputs and n <= MaxNodes then
      cell.x = 140
      cell.y = 40
      cell.value = neuron.value
      cells[n] = cell
    end
  end

  -- Relax hidden-neuron positions toward the neurons they connect to.
  for n=1,4 do
    for _,gene in pairs(genome.genes) do
      if gene.enabled then
        local c1 = cells[gene.into]
        local c2 = cells[gene.out]
        if gene.into > Inputs and gene.into <= MaxNodes then
          c1.x = 0.75*c1.x + 0.25*c2.x
          if c1.x >= c2.x then
            c1.x = c1.x - 40
          end
          if c1.x < 90 then
            c1.x = 90
          end
          if c1.x > 220 then
            c1.x = 220
          end
          c1.y = 0.75*c1.y + 0.25*c2.y
        end
        if gene.out > Inputs and gene.out <= MaxNodes then
          c2.x = 0.25*c1.x + 0.75*c2.x
          if c1.x >= c2.x then
            c2.x = c2.x + 40
          end
          if c2.x < 90 then
            c2.x = 90
          end
          if c2.x > 220 then
            c2.x = 220
          end
          c2.y = 0.25*c1.y + 0.75*c2.y
        end
      end
    end
  end

  gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080)
  for n,cell in pairs(cells) do
    if n > Inputs or cell.value ~= 0 then
      local color = math.floor((cell.value+1)/2*256)
      if color > 255 then color = 255 end
      if color < 0 then color = 0 end
      local opacity = 0xFF000000
      if cell.value == 0 then
        opacity = 0x50000000
      end
      color = opacity + color*0x10000 + color*0x100 + color
      gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color)
    end
  end
  for _,gene in pairs(genome.genes) do
    if gene.enabled then
      local c1 = cells[gene.into]
      local c2 = cells[gene.out]
      local opacity = 0xA0000000
      if c1.value == 0 then
        opacity = 0x20000000
      end

      local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80)
      if gene.weight > 0 then
        color = opacity + 0x8000 + 0x10000*color
      else
        color = opacity + 0x800000 + 0x100*color
      end
      gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color)
    end
  end

  gui.drawBox(49,71,51,78,0x00000000,0x80FF0000)

  if forms.ischecked(showMutationRates) then
    local pos = 100
    for mutation,rate in pairs(genome.mutationRates) do
      gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10)
      pos = pos + 8
    end
  end
end

-- Writes one value per line; loadFile depends on this when it reads
-- mutation-rate names with file:read("*line").
function writeFile(filename)
  local file = io.open(filename, "w")
  file:write(pool.generation .. "\n")
  file:write(pool.maxFitness .. "\n")
  file:write(#pool.species .. "\n")
  for n,species in pairs(pool.species) do
    file:write(species.topFitness .. "\n")
    file:write(species.staleness .. "\n")
    file:write(#species.genomes .. "\n")
    for m,genome in pairs(species.genomes) do
      file:write(genome.fitness .. "\n")
      file:write(genome.maxneuron .. "\n")
      for mutation,rate in pairs(genome.mutationRates) do
        file:write(mutation .. "\n")
        file:write(rate .. "\n")
      end
      file:write("done\n")

      file:write(#genome.genes .. "\n")
      for l,gene in pairs(genome.genes) do
        file:write(gene.into .. " ")
        file:write(gene.out .. " ")
        file:write(gene.weight .. " ")
        file:write(gene.innovation .. " ")
        if(gene.enabled) then
          file:write("1\n")
        else
          file:write("0\n")
        end
      end
    end
  end
  file:close()
end

function savePool()
  local filename = forms.gettext(saveLoadFile)
  writeFile(filename)
end

function loadFile(filename)
  local file = io.open(filename, "r")
  pool = newPool()
  pool.generation = file:read("*number")
  pool.maxFitness = file:read("*number")
  forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  local numSpecies = file:read("*number")
  for s=1,numSpecies do
    local species = newSpecies()
    table.insert(pool.species, species)
    species.topFitness = file:read("*number")
    species.staleness = file:read("*number")
    local numGenomes = file:read("*number")
    for g=1,numGenomes do
      local genome = newGenome()
      table.insert(species.genomes, genome)
      genome.fitness = file:read("*number")
      genome.maxneuron = file:read("*number")
      local line = file:read("*line")
      while line ~= "done" do
        genome.mutationRates[line] = file:read("*number")
        line = file:read("*line")
      end
      local numGenes = file:read("*number")
      for n=1,numGenes do
        local gene = newGene()
        table.insert(genome.genes, gene)
        local enabled
        gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number")
        if enabled == 0 then
          gene.enabled = false
        else
          gene.enabled = true
        end
      end
    end
  end
  file:close()

  while fitnessAlreadyMeasured() do
    nextGenome()
  end
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
end

function loadPool()
  local filename = forms.gettext(saveLoadFile)
  loadFile(filename)
end

function playTop()
  local maxfitness = 0
  local maxs, maxg
  for s,species in pairs(pool.species) do
    for g,genome in pairs(species.genomes) do
      if genome.fitness > maxfitness then
        maxfitness = genome.fitness
        maxs = s
        maxg = g
      end
    end
  end

  pool.currentSpecies = maxs
  pool.currentGenome = maxg
  pool.maxFitness = maxfitness
  forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
  initializeRun()
  pool.currentFrame = pool.currentFrame + 1
  return
end

function onExit()
  forms.destroy(form)
end

writeFile("temp.pool")

event.onexit(onExit)

form = forms.newform(200, 260, "Fitness")
maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8)
showNetwork = forms.checkbox(form, "Show Map", 5, 30)
showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52)
restartButton = forms.button(form, "Restart", initializePool, 5, 77)
saveButton = forms.button(form, "Save", savePool, 5, 102)
loadButton = forms.button(form, "Load", loadPool, 80, 102)
saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148)
saveLoadLabel = forms.label(form, "Save/Load:", 5, 129)
playTopButton = forms.button(form, "Play Top", playTop, 5, 170)
hideBanner = forms.checkbox(form, "Hide Banner", 5, 190)

while true do
  local backgroundColor = 0xD0FFFFFF
  if not forms.ischecked(hideBanner) then
    gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor)
  end

  local species = pool.species[pool.currentSpecies]
  local genome = species.genomes[pool.currentGenome]

  if forms.ischecked(showNetwork) then
    displayGenome(genome)
  end

  if pool.currentFrame%5 == 0 then
    evaluateCurrent()
  end

  joypad.set(controller)

  getPositions()
  if marioX > rightmost then
    rightmost = marioX
    timeout = TimeoutConstant
  end

  timeout = timeout - 1

  -- End the run once Mario stops making rightward progress;
  -- the allowed stall time grows with the length of the run (timeoutBonus).
  local timeoutBonus = pool.currentFrame / 4
  if timeout + timeoutBonus <= 0 then
    -- Fitness: distance reached minus a time penalty, plus a bonus for passing the level end.
    local fitness = rightmost - pool.currentFrame / 2
    if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then
      fitness = fitness + 1000
    end
    if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then
      fitness = fitness + 1000
    end
    -- 0 means "not yet measured" (see fitnessAlreadyMeasured), so a real 0 is stored as -1.
    if fitness == 0 then
      fitness = -1
    end
    genome.fitness = fitness

    if fitness > pool.maxFitness then
      pool.maxFitness = fitness
      forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness))
      writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile))
    end

    console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness)
    pool.currentSpecies = 1
    pool.currentGenome = 1
    while fitnessAlreadyMeasured() do
      nextGenome()
    end
    initializeRun()
  end

  local measured = 0
  local total = 0
  for _,species in pairs(pool.species) do
    for _,genome in pairs(species.genomes) do
      total = total + 1
      if genome.fitness ~= 0 then
        measured = measured + 1
      end
    end
  end
  if not forms.ischecked(hideBanner) then
    gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11)
    gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11)
    gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11)
  end

  pool.currentFrame = pool.currentFrame + 1

  emu.frameadvance()
end
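One detail of `mutate()` in the listing is easy to miss. A "chance" larger than 1, such as `LinkMutationChance = 2.0`, is not a probability but an expected count. For a rate `p`, the `while p > 0` loop makes `floor(p)` guaranteed attempts, plus one more with probability `p - floor(p)`. Because `mutate()` first scales every rate by 0.95 or 1.05263 (roughly 1/0.95, so the two moves about cancel), a link rate of 2.0 becomes 1.9 or about 2.105. That means one or two attempts, or two or three.

The standalone function below isolates that loop so the claim can be checked. `countAttempts` and `averageAttempts` are hypothetical names. Passing the random source in as a parameter is an assumption made for testing; the script calls `math.random()` directly.

```lua
-- Isolates the repeat pattern used in mutate() (hypothetical helper, not part of the script).
-- rate: a mutation rate such as genome.mutationRates["link"]
-- rand: a function returning a number in [0, 1), e.g. math.random
function countAttempts(rate, rand)
  local count = 0
  local p = rate
  while p > 0 do
    if rand() < p then
      count = count + 1   -- in the script, linkMutate/nodeMutate/... would run here
    end
    p = p - 1
  end
  return count
end

-- Monte Carlo check: the average number of attempts approaches the rate itself.
function averageAttempts(rate, trials)
  local total = 0
  for _ = 1, trials do
    total = total + countAttempts(rate, math.random)
  end
  return total / trials
end

math.randomseed(1)
print(averageAttempts(1.9, 100000))   -- close to 1.9
print(averageAttempts(0.4, 100000))   -- close to 0.4 (BiasMutationChance)
```

This is a cheap way to let one number encode "how many times per generation" and have evolution tune it, since the rates themselves are inherited and mutated along with the genes.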